32 template <
typename BottomTensorView_,
33 typename WindowLengths_,
34 typename StaticTileDistribution_,
35 typename StaticPageIndexArray_,
36 typename StaticValidArray_,
55 static constexpr
index_t NDimP = TileDstr::get_num_of_dimension_p();
56 static constexpr
index_t NDimY = TileDstr::get_num_of_dimension_y();
60 static_assert(NumCoord == 1);
65 "wrong! lengths should be static");
68 static_assert(
NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
69 "wrong! inconsistent # of diemsnions");
83 static constexpr
auto get_vector_dim_y_scalar_per_vector()
85 const auto [ys_vector_lengths, ys_vector_strides] =
93 if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
95 ScalarPerVector_ = ys_vector_lengths[i];
100 return make_tuple(VectorDimY_, ScalarPerVector_);
108 get_vector_dim_y_scalar_per_vector().template at<1>();
115 static constexpr
auto scalars_per_access_ = [] {
120 constexpr
auto NDimY_ =
NDimY;
122 return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
125 static constexpr
auto get_space_filling_curve()
127 constexpr
auto tile_dstr =
TileDstr{};
129 constexpr
auto thread_tensor_lengths_ys =
130 to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
137 decltype(scalars_per_access_)>{};
141 using SFC_Ys = decltype(get_space_filling_curve());
145 static_assert(0 <
NumAccess,
"Wrong! NumAccess should be larger than 0");
146 static_assert(
NumAccess % NumCoord == 0,
"wrong! # of access is not divisible by NumCoord");
171 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
175 if constexpr(
NDimP == 1)
180 else if constexpr(
NDimP == 2)
182 window_adaptor_thread_coord_tmp =
195 window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
196 bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
203 using SFC_Ys =
typename Traits::SFC_Ys;
206 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
207 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
209 constexpr
auto idx_diff_ys =
216 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
219 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
246 template <
typename ATopIndex>
250 const ATopIndex& idx_diff_adaptor_top)
const
255 window_adaptor_thread_coord,
256 idx_diff_adaptor_top,
257 idx_diff_adaptor_bottom);
260 bottom_tensor_thread_coord,
261 idx_diff_adaptor_bottom);
268 const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
269 BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
272 const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
273 const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
276 array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
278 array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
281 constexpr
auto window_adaptor_bottom_dims =
282 WindowAdaptor::get_bottom_dimension_hidden_ids();
285 window_adaptor_bottom_dims,
286 window_adaptor_bottom_dim_vector_lengths);
288 window_adaptor_bottom_dims,
289 window_adaptor_bottom_dim_vector_strides);
291 const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
292 WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
293 window_adaptor_vector_lengths, window_adaptor_vector_strides);
306 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
310 constexpr
auto tile_dstr =
TileDstr{};
311 auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
312 load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
316 template <
typename DistributedTensor,
317 index_t i_access_unsupport_ = -1,
318 bool oob_conditional_check =
true>
323 using Traits = load_store_traits;
324 using vector_t =
typename Traits::vector_t;
325 using SFC_Ys =
typename Traits::SFC_Ys;
327 constexpr
auto tile_dstr =
TileDstr{};
330 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
335 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
336 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
339 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
340 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
341 const auto page_offset =
page_idx_[idx_gather];
344 const vector_t vec_value = [&]() {
345 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
348 bottom_tensor_thread_coord,
350 bool_constant<oob_conditional_check>{});
355 bottom_tensor_thread_coord,
358 bool_constant<oob_conditional_check>{});
363 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
366 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
372 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
375 dst_tensor.get_thread_buffer().template at<d>() =
376 vec_value.template get_as<DataType>()[j / Traits::PackedSize];
380 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
381 static_assert(d % Traits::ScalarPerVector == 0);
383 dst_tensor.get_thread_buffer().template get_as<vector_t>()(
384 number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
389 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
392 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
396 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
397 forward_step_scatter);
400 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
406 template <
typename LdsTileWindow_,
407 index_t i_access_unsupport_ = -1,
408 bool oob_conditional_check =
true>
413 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
414 using LdsDataType =
typename LdsTileWindow::DataType;
415 using Traits = load_store_traits;
416 using vector_t =
typename Traits::vector_t;
417 using SFC_Ys =
typename Traits::SFC_Ys;
419 constexpr
auto tile_dstr =
TileDstr{};
422 const auto window_origin = lds_tile.get_window_origin();
423 const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
424 const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
425 auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
428 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
436 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
437 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
440 auto lds_bottom_tensor_thread_idx =
441 window_origin + lds_window_adaptor_thread_coord.get_bottom_index();
443 const auto lds_coord =
446 CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
449 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
450 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
451 const auto page_offset =
page_idx_[idx_gather];
454 auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
455 mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
458 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
461 mixed_bottom_thread_coord,
463 bool_constant<oob_conditional_check>{});
467 mixed_bottom_thread_coord,
470 bool_constant<oob_conditional_check>{});
475 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
478 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
482 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
483 forward_step_scatter);
486 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
490 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
492 lds_window_adaptor_thread_coord,
493 lds_bottom_tensor_thread_coord,
501 template <
typename LdsTileWindow_,
502 index_t i_access_unsupport_ = -1,
503 bool oob_conditional_check =
true,
504 bool pre_nop =
false>
508 bool_constant<pre_nop> = {})
const
510 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
512 using LdsDataType =
typename LdsTileWindow::DataType;
516 static_assert(LdsTileWindow::get_num_of_dimension() == 3);
519 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
520 make_tuple(number<0>{}, number<0>{}, number<0>{})) *
524 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
525 make_tuple(number<0>{}, number<1>{}, number<0>{})) *
526 sizeof(LdsDataType) -
530 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
531 make_tuple(number<1>{}, number<0>{}, number<0>{})) *
532 sizeof(LdsDataType) -
535 const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
539 using Traits = load_store_traits;
542 using vector_t =
typename Traits::vector_t;
543 using SFC_Ys =
typename Traits::SFC_Ys;
545 LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
548 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
553 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
554 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
555 constexpr
auto pre_nop_ = [&]() {
556 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
559 return bool_constant<false>{};
562 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
563 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
564 const auto page_offset =
page_idx_[idx_gather];
567 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
570 smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
576 bottom_tensor_thread_coord,
586 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
589 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
593 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
594 forward_step_scatter);
597 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
605 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
610 using Traits = load_store_traits;
613 using vector_t =
typename Traits::vector_t;
614 using SFC_Ys =
typename Traits::SFC_Ys;
616 constexpr
auto tile_dstr =
TileDstr{};
618 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
622 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
623 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
626 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
627 constexpr
auto idx_gather = idx_ys_start[number<0>{}];
628 const auto page_offset =
page_idx_[idx_gather];
633 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
636 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
642 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
645 vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
650 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
653 bottom_tensor_thread_coord,
656 bool_constant<oob_conditional_check>{});
661 bottom_tensor_thread_coord,
665 bool_constant<oob_conditional_check>{});
670 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
673 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
677 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
678 forward_step_scatter);
681 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
687 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
692 using Traits = load_store_traits;
695 using vector_t =
typename Traits::vector_t;
696 using SFC_Ys =
typename Traits::SFC_Ys;
698 constexpr
auto tile_dstr =
TileDstr{};
701 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
705 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
706 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
709 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
710 constexpr
auto idx_gather = idx_ys_start[number<0>{}];
711 const auto page_offset =
page_idx_[idx_gather];
720 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
723 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
729 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
732 vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
739 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
742 bottom_tensor_thread_coord,
745 bool_constant<oob_conditional_check>{});
750 bottom_tensor_thread_coord,
754 bool_constant<oob_conditional_check>{});
761 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
764 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
768 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
769 forward_step_scatter);
772 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
785 step_new(HsGatherDim) = 0;
797 if constexpr(std::is_same_v<ValidArray, std::nullptr_t> ==
false)
818 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
822 if constexpr(
NDimP == 1)
827 else if constexpr(
NDimP == 2)
829 window_adaptor_thread_coord_tmp =
842 window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
844 bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
851 using SFC_Ys =
typename Traits::SFC_Ys;
854 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
855 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
857 constexpr
auto idx_diff_ys =
864 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
867 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
898 template <
typename TensorView_,
899 typename WindowLengths_,
900 typename StaticTileDistribution_,
901 typename StaticPageIndexArray_,
906 const WindowLengths_& window_lengths,
907 const multi_index<TensorView_::get_num_of_dimension()>& origin,
909 const StaticPageIndexArray_& page_idx,
911 number<NumCoord> = {})
913 return tile_scatter_gather<remove_cvref_t<TensorView_>,
914 remove_cvref_t<WindowLengths_>,
915 remove_cvref_t<StaticTileDistribution_>,
916 remove_cvref_t<StaticPageIndexArray_>,
920 tensor_view, window_lengths, origin, tile_distribution, page_idx,
nullptr};
923 template <
typename TensorView,
924 typename WindowLengths,
925 typename StaticTileDistribution,
926 typename StaticPageIndexArray,
930 const multi_index<TensorView::get_num_of_dimension()>& origin,
932 const StaticPageIndexArray& page_idx,
940 number<HsGatherDim>{});
943 template <
typename TensorView,
944 typename WindowLengths,
945 typename StaticTileDistribution,
946 typename StaticPageIndexArray,
951 const StaticPageIndexArray& page_idx,
959 number<HsGatherDim>{});
962 template <
typename TensorView_,
963 typename WindowLengths_,
964 typename StaticTileDistribution_,
965 typename StaticPageIndexArray_,
966 typename StaticValidArray_,
971 const WindowLengths_& window_lengths,
972 const multi_index<TensorView_::get_num_of_dimension()>& origin,
974 const StaticPageIndexArray_& page_idx,
975 const StaticValidArray_& valids,
977 number<NumCoord> = {})
979 return tile_scatter_gather<remove_cvref_t<TensorView_>,
980 remove_cvref_t<WindowLengths_>,
981 remove_cvref_t<StaticTileDistribution_>,
982 remove_cvref_t<StaticPageIndexArray_>,
983 remove_cvref_t<StaticValidArray_>,
986 tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
989 template <
typename TensorView,
990 typename WindowLengths,
991 typename StaticTileDistribution,
992 typename StaticPageIndexArray,
993 typename StaticValidArray,
997 const multi_index<TensorView::get_num_of_dimension()>& origin,
999 const StaticPageIndexArray& page_idx,
1000 const StaticValidArray& valids,
1009 number<HsGatherDim>{});
1012 template <
typename TensorView,
1013 typename WindowLengths,
1014 typename StaticTileDistribution,
1015 typename StaticPageIndexArray,
1016 typename StaticValidArray,
1021 const StaticPageIndexArray& page_idx,
1022 const StaticValidArray& valids,
1031 number<HsGatherDim>{});
1034 template <
typename NewTensorView_,
1035 typename OldTensorView_,
1036 typename WindowLengths_,
1037 typename StaticTileDistribution_,
1038 typename StaticPageIndexArray_,
1039 typename StaticValidArray_,
1045 StaticTileDistribution_,
1046 StaticPageIndexArray_,
1049 NumCoord>& tile_window)
1054 tile_window.tile_dstr_,
1055 tile_window.page_idx_,
1056 tile_window.valids_);
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_LDS_ADDR
Definition: config.hpp:62
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_ &new_tensor_view, const tile_scatter_gather< OldTensorView_, WindowLengths_, StaticTileDistribution_, StaticPageIndexArray_, StaticValidArray_, HsGatherDim, NumCoord > &tile_window)
Definition: tile_scatter_gather.hpp:1042
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:56
constant< b > bool_constant
Definition: integral_constant.hpp:43
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1126
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constant< v > number
Definition: integral_constant.hpp:37
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1066
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:15
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_DEVICE auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition: tile_scatter_gather.hpp:905
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:98
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
Definition: sequence.hpp:298
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:313
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Definition: numeric.hpp:81
Definition: space_filling_curve.hpp:20
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: tile_distribution.hpp:70
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:124
Definition: tile_scatter_gather.hpp:81
static constexpr index_t PackedSize
Definition: tile_scatter_gather.hpp:104
static constexpr index_t NumAccess
Definition: tile_scatter_gather.hpp:143
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_scatter_gather.hpp:141
static constexpr index_t VectorDimY
Definition: tile_scatter_gather.hpp:106
static constexpr index_t ScalarPerVector
Definition: tile_scatter_gather.hpp:107
This class provides a tile (windowed) view of, and access to, the device memory.
Definition: tile_scatter_gather.hpp:41
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_scatter_gather.hpp:781
static constexpr index_t NumAccessPerCoord
Definition: tile_scatter_gather.hpp:149
static constexpr auto I1
Definition: tile_scatter_gather.hpp:59
constexpr CK_TILE_DEVICE tile_scatter_gather(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution, const PageIdxArray &page_idx, const ValidArray &valids)
Definition: tile_scatter_gather.hpp:153
BottomTensorIndex window_origin_
Definition: tile_scatter_gather.hpp:881
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:307
WindowLengths window_lengths_
Definition: tile_scatter_gather.hpp:878
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_scatter_gather.hpp:232
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_scatter_gather.hpp:304
static constexpr index_t NDimBottomTensor
Definition: tile_scatter_gather.hpp:53
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_scatter_gather.hpp:265
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_scatter_gather.hpp:72
PageIdxArray page_idx_
Definition: tile_scatter_gather.hpp:888
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:409
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_scatter_gather.hpp:43
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_scatter_gather.hpp:810
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_scatter_gather.hpp:894
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_scatter_gather.hpp:236
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_scatter_gather.hpp:44
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_scatter_gather.hpp:247
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:319
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:688
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray &new_idx, const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:803
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_scatter_gather.hpp:48
TileDstr tile_dstr_
Definition: tile_scatter_gather.hpp:886
ValidArray valids_
Definition: tile_scatter_gather.hpp:889
static constexpr index_t NDimY
Definition: tile_scatter_gather.hpp:56
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_scatter_gather.hpp:50
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_scatter_gather.hpp:52
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_scatter_gather.hpp:225
remove_cvref_t< StaticValidArray_ > ValidArray
Definition: tile_scatter_gather.hpp:46
static constexpr index_t NDimP
Definition: tile_scatter_gather.hpp:55
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_scatter_gather.hpp:42
constexpr CK_TILE_DEVICE tile_scatter_gather()=default
CK_TILE_DEVICE void update(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:606
remove_cvref_t< StaticPageIndexArray_ > PageIdxArray
Definition: tile_scatter_gather.hpp:45
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_scatter_gather.hpp:230
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_scatter_gather.hpp:871
static constexpr auto I0
Definition: tile_scatter_gather.hpp:58
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition: tile_scatter_gather.hpp:78
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_scatter_gather.hpp:234
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_scatter_gather.hpp:47
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_scatter_gather.hpp:75
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_scatter_gather.hpp:239
BottomTensorView bottom_tensor_view_
Definition: tile_scatter_gather.hpp:875
CK_TILE_DEVICE void update_valids(const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:795
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_scatter_gather.hpp:505
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_scatter_gather.hpp:71
CK_TILE_DEVICE void update_page_idx(const PageIdxArray &new_idx)
Definition: tile_scatter_gather.hpp:793
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_scatter_gather.hpp:223
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
BottomTensorIndex window_origin_
Definition: tile_window_base.hpp:79
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
WindowLengths window_lengths_
Definition: tile_window_base.hpp:81
This class provides a description of a tile (windowed) view on the device memory.
Definition: tile_window.hpp:1195
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10