33 template <
typename BottomTensorView_,
34 typename WindowLengths_,
35 typename StaticTileDistribution_,
36 typename StaticPageIndexArray_,
37 typename StaticValidArray_,
56 static constexpr
index_t NDimP = TileDstr::get_num_of_dimension_p();
57 static constexpr
index_t NDimY = TileDstr::get_num_of_dimension_y();
61 static_assert(NumCoord == 1);
66 "wrong! lengths should be static");
69 static_assert(
NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
70 "wrong! inconsistent # of diemsnions");
84 static constexpr
auto get_vector_dim_y_scalar_per_vector()
86 const auto [ys_vector_lengths, ys_vector_strides] =
94 if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
96 ScalarPerVector_ = ys_vector_lengths[i];
101 return make_tuple(VectorDimY_, ScalarPerVector_);
109 get_vector_dim_y_scalar_per_vector().template at<1>();
116 static constexpr
auto scalars_per_access_ = [] {
121 constexpr
auto NDimY_ =
NDimY;
123 return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
126 static constexpr
auto get_space_filling_curve()
128 constexpr
auto tile_dstr =
TileDstr{};
130 constexpr
auto thread_tensor_lengths_ys =
131 to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
138 decltype(scalars_per_access_)>{};
142 using SFC_Ys = decltype(get_space_filling_curve());
146 static_assert(0 <
NumAccess,
"Wrong! NumAccess should be larger than 0");
147 static_assert(
NumAccess % NumCoord == 0,
"wrong! # of access is not divisible by NumCoord");
172 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
176 if constexpr(
NDimP == 1)
181 else if constexpr(
NDimP == 2)
183 window_adaptor_thread_coord_tmp =
197 window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
198 bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
205 using SFC_Ys =
typename Traits::SFC_Ys;
208 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
209 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
211 constexpr
auto idx_diff_ys =
218 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
221 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
248 template <
typename ATopIndex>
252 const ATopIndex& idx_diff_adaptor_top)
const
257 window_adaptor_thread_coord,
258 idx_diff_adaptor_top,
259 idx_diff_adaptor_bottom);
262 bottom_tensor_thread_coord,
263 idx_diff_adaptor_bottom);
270 const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
271 BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
274 const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
275 const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
278 array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
280 array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
283 constexpr
auto window_adaptor_bottom_dims =
284 WindowAdaptor::get_bottom_dimension_hidden_ids();
287 window_adaptor_bottom_dims,
288 window_adaptor_bottom_dim_vector_lengths);
290 window_adaptor_bottom_dims,
291 window_adaptor_bottom_dim_vector_strides);
293 const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
294 WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
295 window_adaptor_vector_lengths, window_adaptor_vector_strides);
308 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
312 constexpr
auto tile_dstr =
TileDstr{};
313 auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
314 load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
318 template <
typename DistributedTensor,
319 index_t i_access_unsupport_ = -1,
320 bool oob_conditional_check =
true>
325 using Traits = load_store_traits;
326 using vector_t =
typename Traits::vector_t;
327 using SFC_Ys =
typename Traits::SFC_Ys;
329 constexpr
auto tile_dstr =
TileDstr{};
332 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
337 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
338 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
341 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
342 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
343 const auto page_offset =
page_idx_[idx_gather];
346 const vector_t vec_value = [&]() {
347 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
350 bottom_tensor_thread_coord,
352 bool_constant<oob_conditional_check>{});
357 bottom_tensor_thread_coord,
360 bool_constant<oob_conditional_check>{});
365 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
368 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
374 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
377 dst_tensor.get_thread_buffer().template at<d>() =
378 vec_value.template get_as<DataType>()[j / Traits::PackedSize];
382 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
383 static_assert(d % Traits::ScalarPerVector == 0);
385 dst_tensor.get_thread_buffer().template get_as<vector_t>()(
386 number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
391 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
394 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
398 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
399 forward_step_scatter);
402 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
408 template <
typename LdsTileWindow_,
409 index_t i_access_unsupport_ = -1,
410 bool oob_conditional_check =
true>
415 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
416 using LdsDataType =
typename LdsTileWindow::DataType;
417 using Traits = load_store_traits;
418 using vector_t =
typename Traits::vector_t;
419 using SFC_Ys =
typename Traits::SFC_Ys;
421 constexpr
auto tile_dstr =
TileDstr{};
424 const auto window_origin = lds_tile.get_window_origin();
425 const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
426 const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
427 auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
430 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
438 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
439 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
442 auto lds_bottom_tensor_thread_idx =
443 window_origin + lds_window_adaptor_thread_coord.get_bottom_index();
445 const auto lds_coord =
448 CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
451 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
452 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
453 const auto page_offset =
page_idx_[idx_gather];
456 auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
457 mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
460 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
463 mixed_bottom_thread_coord,
465 bool_constant<oob_conditional_check>{});
469 mixed_bottom_thread_coord,
472 bool_constant<oob_conditional_check>{});
477 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
480 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
484 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
485 forward_step_scatter);
488 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
492 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
494 lds_window_adaptor_thread_coord,
495 lds_bottom_tensor_thread_coord,
503 template <
typename LdsTileWindow_,
504 index_t i_access_unsupport_ = -1,
505 bool oob_conditional_check =
true,
506 bool pre_nop =
false>
510 bool_constant<pre_nop> = {})
const
512 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
514 using LdsDataType =
typename LdsTileWindow::DataType;
518 static_assert(LdsTileWindow::get_num_of_dimension() == 3);
521 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
522 make_tuple(number<0>{}, number<0>{}, number<0>{})) *
526 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
527 make_tuple(number<0>{}, number<1>{}, number<0>{})) *
528 sizeof(LdsDataType) -
532 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
533 make_tuple(number<1>{}, number<0>{}, number<0>{})) *
534 sizeof(LdsDataType) -
537 const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
540 using Traits = load_store_traits;
543 using vector_t =
typename Traits::vector_t;
544 using SFC_Ys =
typename Traits::SFC_Ys;
546 LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
549 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
554 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
555 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
556 constexpr
auto pre_nop_ = [&]() {
557 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
560 return bool_constant<false>{};
563 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
564 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
565 const auto page_offset =
page_idx_[idx_gather];
568 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
571 smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
577 bottom_tensor_thread_coord,
587 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
590 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
594 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
595 forward_step_scatter);
598 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
606 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
611 using Traits = load_store_traits;
614 using vector_t =
typename Traits::vector_t;
615 using SFC_Ys =
typename Traits::SFC_Ys;
617 constexpr
auto tile_dstr =
TileDstr{};
619 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
623 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
624 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
627 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
628 constexpr
auto idx_gather = idx_ys_start[number<0>{}];
629 const auto page_offset =
page_idx_[idx_gather];
634 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
637 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
643 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
646 vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
651 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
654 bottom_tensor_thread_coord,
657 bool_constant<oob_conditional_check>{});
662 bottom_tensor_thread_coord,
666 bool_constant<oob_conditional_check>{});
671 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
674 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
678 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
679 forward_step_scatter);
682 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
688 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
693 using Traits = load_store_traits;
696 using vector_t =
typename Traits::vector_t;
697 using SFC_Ys =
typename Traits::SFC_Ys;
699 constexpr
auto tile_dstr =
TileDstr{};
702 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
706 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
707 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
710 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
711 constexpr
auto idx_gather = idx_ys_start[number<0>{}];
712 const auto page_offset =
page_idx_[idx_gather];
721 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
724 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
730 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
733 vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
740 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
743 bottom_tensor_thread_coord,
746 bool_constant<oob_conditional_check>{});
751 bottom_tensor_thread_coord,
755 bool_constant<oob_conditional_check>{});
762 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
765 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
769 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
770 forward_step_scatter);
773 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
786 step_new(HsGatherDim) = 0;
798 if constexpr(std::is_same_v<ValidArray, std::nullptr_t> ==
false)
819 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
823 if constexpr(
NDimP == 1)
828 else if constexpr(
NDimP == 2)
830 window_adaptor_thread_coord_tmp =
843 window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
845 bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
852 using SFC_Ys =
typename Traits::SFC_Ys;
855 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
856 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
858 constexpr
auto idx_diff_ys =
865 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
868 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
899 template <
typename TensorView_,
900 typename WindowLengths_,
901 typename StaticTileDistribution_,
902 typename StaticPageIndexArray_,
907 const WindowLengths_& window_lengths,
908 const multi_index<TensorView_::get_num_of_dimension()>& origin,
910 const StaticPageIndexArray_& page_idx,
912 number<NumCoord> = {})
914 return tile_scatter_gather<remove_cvref_t<TensorView_>,
915 remove_cvref_t<WindowLengths_>,
916 remove_cvref_t<StaticTileDistribution_>,
917 remove_cvref_t<StaticPageIndexArray_>,
921 tensor_view, window_lengths, origin, tile_distribution, page_idx,
nullptr};
924 template <
typename TensorView,
925 typename WindowLengths,
926 typename StaticTileDistribution,
927 typename StaticPageIndexArray,
931 const multi_index<TensorView::get_num_of_dimension()>& origin,
933 const StaticPageIndexArray& page_idx,
941 number<HsGatherDim>{});
944 template <
typename TensorView,
945 typename WindowLengths,
946 typename StaticTileDistribution,
947 typename StaticPageIndexArray,
952 const StaticPageIndexArray& page_idx,
960 number<HsGatherDim>{});
963 template <
typename TensorView_,
964 typename WindowLengths_,
965 typename StaticTileDistribution_,
966 typename StaticPageIndexArray_,
967 typename StaticValidArray_,
972 const WindowLengths_& window_lengths,
973 const multi_index<TensorView_::get_num_of_dimension()>& origin,
975 const StaticPageIndexArray_& page_idx,
976 const StaticValidArray_& valids,
978 number<NumCoord> = {})
980 return tile_scatter_gather<remove_cvref_t<TensorView_>,
981 remove_cvref_t<WindowLengths_>,
982 remove_cvref_t<StaticTileDistribution_>,
983 remove_cvref_t<StaticPageIndexArray_>,
984 remove_cvref_t<StaticValidArray_>,
987 tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
990 template <
typename TensorView,
991 typename WindowLengths,
992 typename StaticTileDistribution,
993 typename StaticPageIndexArray,
994 typename StaticValidArray,
998 const multi_index<TensorView::get_num_of_dimension()>& origin,
1000 const StaticPageIndexArray& page_idx,
1001 const StaticValidArray& valids,
1010 number<HsGatherDim>{});
1013 template <
typename TensorView,
1014 typename WindowLengths,
1015 typename StaticTileDistribution,
1016 typename StaticPageIndexArray,
1017 typename StaticValidArray,
1022 const StaticPageIndexArray& page_idx,
1023 const StaticValidArray& valids,
1032 number<HsGatherDim>{});
1035 template <
typename NewTensorView_,
1036 typename OldTensorView_,
1037 typename WindowLengths_,
1038 typename StaticTileDistribution_,
1039 typename StaticPageIndexArray_,
1040 typename StaticValidArray_,
1046 StaticTileDistribution_,
1047 StaticPageIndexArray_,
1050 NumCoord>& tile_window)
1055 tile_window.tile_dstr_,
1056 tile_window.page_idx_,
1057 tile_window.valids_);
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_ &new_tensor_view, const tile_scatter_gather< OldTensorView_, WindowLengths_, StaticTileDistribution_, StaticPageIndexArray_, StaticValidArray_, HsGatherDim, NumCoord > &tile_window)
Definition: tile_scatter_gather.hpp:1043
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constant< b > bool_constant
Definition: integral_constant.hpp:43
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1115
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constant< v > number
Definition: integral_constant.hpp:37
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1055
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:15
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_DEVICE auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition: tile_scatter_gather.hpp:906
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:97
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
Definition: sequence.hpp:287
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:302
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Definition: numeric.hpp:81
Definition: space_filling_curve.hpp:20
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
Definition: tile_scatter_gather.hpp:82
static constexpr index_t PackedSize
Definition: tile_scatter_gather.hpp:105
static constexpr index_t NumAccess
Definition: tile_scatter_gather.hpp:144
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_scatter_gather.hpp:142
static constexpr index_t VectorDimY
Definition: tile_scatter_gather.hpp:107
static constexpr index_t ScalarPerVector
Definition: tile_scatter_gather.hpp:108
This class provides tile (windowed) view and access to the device memory.
Definition: tile_scatter_gather.hpp:42
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_scatter_gather.hpp:782
static constexpr index_t NumAccessPerCoord
Definition: tile_scatter_gather.hpp:150
static constexpr auto I1
Definition: tile_scatter_gather.hpp:60
constexpr CK_TILE_DEVICE tile_scatter_gather(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution, const PageIdxArray &page_idx, const ValidArray &valids)
Definition: tile_scatter_gather.hpp:154
BottomTensorIndex window_origin_
Definition: tile_scatter_gather.hpp:882
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:309
WindowLengths window_lengths_
Definition: tile_scatter_gather.hpp:879
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_scatter_gather.hpp:234
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_scatter_gather.hpp:306
static constexpr index_t NDimBottomTensor
Definition: tile_scatter_gather.hpp:54
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_scatter_gather.hpp:267
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_scatter_gather.hpp:73
PageIdxArray page_idx_
Definition: tile_scatter_gather.hpp:889
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:411
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_scatter_gather.hpp:44
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_scatter_gather.hpp:811
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_scatter_gather.hpp:895
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_scatter_gather.hpp:238
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_scatter_gather.hpp:45
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_scatter_gather.hpp:249
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:321
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:689
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray &new_idx, const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:804
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_scatter_gather.hpp:49
TileDstr tile_dstr_
Definition: tile_scatter_gather.hpp:887
ValidArray valids_
Definition: tile_scatter_gather.hpp:890
static constexpr index_t NDimY
Definition: tile_scatter_gather.hpp:57
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_scatter_gather.hpp:51
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_scatter_gather.hpp:53
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_scatter_gather.hpp:227
remove_cvref_t< StaticValidArray_ > ValidArray
Definition: tile_scatter_gather.hpp:47
static constexpr index_t NDimP
Definition: tile_scatter_gather.hpp:56
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_scatter_gather.hpp:43
constexpr CK_TILE_DEVICE tile_scatter_gather()=default
CK_TILE_DEVICE void update(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:607
remove_cvref_t< StaticPageIndexArray_ > PageIdxArray
Definition: tile_scatter_gather.hpp:46
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_scatter_gather.hpp:232
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_scatter_gather.hpp:872
static constexpr auto I0
Definition: tile_scatter_gather.hpp:59
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition: tile_scatter_gather.hpp:79
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_scatter_gather.hpp:236
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_scatter_gather.hpp:48
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_scatter_gather.hpp:76
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_scatter_gather.hpp:241
BottomTensorView bottom_tensor_view_
Definition: tile_scatter_gather.hpp:876
CK_TILE_DEVICE void update_valids(const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:796
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_scatter_gather.hpp:507
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_scatter_gather.hpp:72
CK_TILE_DEVICE void update_page_idx(const PageIdxArray &new_idx)
Definition: tile_scatter_gather.hpp:794
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_scatter_gather.hpp:225
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
BottomTensorIndex window_origin_
Definition: tile_window_base.hpp:79
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
WindowLengths window_lengths_
Definition: tile_window_base.hpp:81
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:1016
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10