32 template <
typename BottomTensorView_,
33 typename WindowLengths_,
34 typename StaticTileDistribution_,
35 typename StaticPageIndexArray_,
36 typename StaticValidArray_,
55 static constexpr
index_t NDimP = TileDstr::get_num_of_dimension_p();
56 static constexpr
index_t NDimY = TileDstr::get_num_of_dimension_y();
60 static_assert(NumCoord == 1);
65 "wrong! lengths should be static");
68 static_assert(
NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
69 "wrong! inconsistent # of diemsnions");
83 static constexpr
auto get_vector_dim_y_scalar_per_vector()
85 const auto [ys_vector_lengths, ys_vector_strides] =
93 if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
95 ScalarPerVector_ = ys_vector_lengths[i];
100 return make_tuple(VectorDimY_, ScalarPerVector_);
108 get_vector_dim_y_scalar_per_vector().template at<1>();
115 static constexpr
auto scalars_per_access_ = [] {
120 constexpr
auto NDimY_ =
NDimY;
122 return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
125 static constexpr
auto get_space_filling_curve()
127 constexpr
auto tile_dstr =
TileDstr{};
129 constexpr
auto thread_tensor_lengths_ys =
130 to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
137 decltype(scalars_per_access_)>{};
141 using SFC_Ys = decltype(get_space_filling_curve());
145 static_assert(0 <
NumAccess,
"Wrong! NumAccess should be larger than 0");
146 static_assert(
NumAccess % NumCoord == 0,
"wrong! # of access is not divisible by NumCoord");
171 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
175 if constexpr(
NDimP == 1)
180 else if constexpr(
NDimP == 2)
182 window_adaptor_thread_coord_tmp =
195 window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
196 bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
203 using SFC_Ys =
typename Traits::SFC_Ys;
206 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
207 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
209 constexpr
auto idx_diff_ys =
216 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
219 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
221 if constexpr(BottomTensorView::buffer_view::get_address_space() ==
222 address_space_enum::global)
226 auto use_lane_id_0 = partition_index;
227 use_lane_id_0[1] = 0;
233 window_origin + window_adaptor_thread_coord_tmp_warp.get_bottom_index();
234 bottom_tensor_thread_origin_idx_tmp_warp(HsGatherDim) = 0;
235 const auto bottom_tensor_thread_coord_tmp_warp =
237 bottom_tensor_thread_origin_idx_tmp_warp);
242 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp_warp;
243 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp_warp;
245 constexpr
auto idx_diff_ys =
253 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
256 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
284 template <
typename ATopIndex>
288 const ATopIndex& idx_diff_adaptor_top)
const
293 window_adaptor_thread_coord,
294 idx_diff_adaptor_top,
295 idx_diff_adaptor_bottom);
298 bottom_tensor_thread_coord,
299 idx_diff_adaptor_bottom);
306 const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
307 BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
310 const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
311 const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
314 array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
316 array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
319 constexpr
auto window_adaptor_bottom_dims =
320 WindowAdaptor::get_bottom_dimension_hidden_ids();
323 window_adaptor_bottom_dims,
324 window_adaptor_bottom_dim_vector_lengths);
326 window_adaptor_bottom_dims,
327 window_adaptor_bottom_dim_vector_strides);
329 const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
330 WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
331 window_adaptor_vector_lengths, window_adaptor_vector_strides);
344 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
348 constexpr
auto tile_dstr =
TileDstr{};
349 auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
350 load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
354 template <
typename DistributedTensor,
355 index_t i_access_unsupport_ = -1,
356 bool oob_conditional_check =
true>
361 using Traits = load_store_traits;
362 using vector_t =
typename Traits::vector_t;
363 using SFC_Ys =
typename Traits::SFC_Ys;
365 constexpr
auto tile_dstr =
TileDstr{};
368 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
373 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
374 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
377 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
378 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
379 const auto page_offset =
page_idx_[idx_gather];
382 const vector_t vec_value = [&]() {
383 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
386 bottom_tensor_thread_coord,
388 bool_constant<oob_conditional_check>{});
393 bottom_tensor_thread_coord,
396 bool_constant<oob_conditional_check>{});
401 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
404 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
410 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
413 dst_tensor.get_thread_buffer().template at<d>() =
414 vec_value.template get_as<DataType>()[j / Traits::PackedSize];
418 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
419 static_assert(d % Traits::ScalarPerVector == 0);
421 dst_tensor.get_thread_buffer().template get_as<vector_t>()(
422 number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
427 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
430 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
434 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
435 forward_step_scatter);
438 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
444 template <
typename LdsTileWindow_,
445 index_t i_access_unsupport_ = -1,
446 bool oob_conditional_check =
true>
451 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
452 using LdsDataType =
typename LdsTileWindow::DataType;
453 using Traits = load_store_traits;
454 using vector_t =
typename Traits::vector_t;
455 using SFC_Ys =
typename Traits::SFC_Ys;
457 constexpr
auto tile_dstr =
TileDstr{};
460 const auto window_origin = lds_tile.get_window_origin();
461 const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
462 const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
463 auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
466 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
474 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
475 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
478 auto lds_bottom_tensor_thread_idx =
479 window_origin + lds_window_adaptor_thread_coord.get_bottom_index();
481 const auto lds_coord =
484 CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
487 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
488 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
489 const auto page_offset =
page_idx_[idx_gather];
492 auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
493 mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
496 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
499 mixed_bottom_thread_coord,
501 bool_constant<oob_conditional_check>{});
505 mixed_bottom_thread_coord,
508 bool_constant<oob_conditional_check>{});
513 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
516 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
520 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
521 forward_step_scatter);
524 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
528 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
530 lds_window_adaptor_thread_coord,
531 lds_bottom_tensor_thread_coord,
539 template <
typename LdsTileWindow_,
540 index_t i_access_unsupport_ = -1,
541 bool oob_conditional_check =
true,
542 bool pre_nop =
false>
546 bool_constant<pre_nop> = {})
const
548 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
550 using LdsDataType =
typename LdsTileWindow::DataType;
554 static_assert(LdsTileWindow::get_num_of_dimension() == 3);
557 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
558 make_tuple(number<0>{}, number<0>{}, number<0>{})) *
562 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
563 make_tuple(number<0>{}, number<1>{}, number<0>{})) *
564 sizeof(LdsDataType) -
568 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
569 make_tuple(number<1>{}, number<0>{}, number<0>{})) *
570 sizeof(LdsDataType) -
573 const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
577 using Traits = load_store_traits;
580 using vector_t =
typename Traits::vector_t;
581 using SFC_Ys =
typename Traits::SFC_Ys;
583 LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
586 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
591 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
592 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
593 constexpr
auto pre_nop_ = [&]() {
594 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
597 return bool_constant<false>{};
600 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
601 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
602 const auto page_offset =
page_idx_[idx_gather];
605 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
608 smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
614 bottom_tensor_thread_coord,
624 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
627 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
631 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
632 forward_step_scatter);
635 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
644 template <
typename LdsTileWindow_,
645 index_t i_access_unsupport_ = -1,
646 bool oob_conditional_check =
true,
647 bool static_move_ys =
false,
648 typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
650 LdsTileWindow_&& lds_tile,
653 bool_constant<static_move_ys> = {})
const
655 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
656 using LdsDataType =
typename LdsTileWindow::DataType;
658 using Traits = load_store_traits;
660 using vector_t =
typename Traits::vector_t;
661 using SFC_Ys =
typename Traits::SFC_Ys;
664 const auto window_origin = lds_tile.get_window_origin();
665 const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
666 const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
667 auto lds_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
669 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
676 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
677 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
679 constexpr
auto idx_ys_offset = [&]() {
680 constexpr
auto idx_off_ys = SFC_Ys::get_step_between(number<0>{}, iAccess);
682 StaticTileDistribution_{}.get_ps_ys_to_xs_adaptor(),
685 return adapter_ys_offset.get_bottom_index();
687 const auto lds_ys_offset = [&]() {
688 if constexpr(static_move_ys)
690 const auto coord_ys_offset =
692 return coord_ys_offset.get_offset();
699 auto lds_bottom_tensor_thread_idx =
700 window_origin + window_adaptor_warp_coord.get_bottom_index();
701 const auto lds_coord =
706 lds_coord.get_offset() / Traits::PackedSize +
707 lds_ys_offset / Traits::PackedSize;
709 const auto dram_ys_offset = [&]() {
710 if constexpr(static_move_ys)
714 return coord_ys_offset.get_offset();
720 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
721 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
722 const auto page_offset =
page_idx_[idx_gather];
724 auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
725 mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
727 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
731 mixed_bottom_thread_coord,
732 offset + dram_ys_offset,
733 bool_constant<oob_conditional_check>{});
739 mixed_bottom_thread_coord,
740 offset + dram_ys_offset,
742 bool_constant<oob_conditional_check>{});
748 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
751 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
755 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
756 forward_step_scatter);
758 if constexpr(!static_move_ys)
760 window_adaptor_thread_coord,
761 bottom_tensor_thread_coord,
764 if constexpr(!static_move_ys)
766 window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
772 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
777 using Traits = load_store_traits;
780 using vector_t =
typename Traits::vector_t;
781 using SFC_Ys =
typename Traits::SFC_Ys;
783 constexpr
auto tile_dstr =
TileDstr{};
785 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
789 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
790 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
793 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
794 constexpr
auto idx_gather = idx_ys_start[number<0>{}];
795 const auto page_offset =
page_idx_[idx_gather];
800 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
803 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
809 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
812 vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
817 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
820 bottom_tensor_thread_coord,
823 bool_constant<oob_conditional_check>{});
828 bottom_tensor_thread_coord,
832 bool_constant<oob_conditional_check>{});
837 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
840 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
844 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
845 forward_step_scatter);
848 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
854 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
859 using Traits = load_store_traits;
862 using vector_t =
typename Traits::vector_t;
863 using SFC_Ys =
typename Traits::SFC_Ys;
865 constexpr
auto tile_dstr =
TileDstr{};
868 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
872 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
873 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
876 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
877 constexpr
auto idx_gather = idx_ys_start[number<0>{}];
878 const auto page_offset =
page_idx_[idx_gather];
887 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
890 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
896 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
899 vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
906 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
909 bottom_tensor_thread_coord,
912 bool_constant<oob_conditional_check>{});
917 bottom_tensor_thread_coord,
921 bool_constant<oob_conditional_check>{});
928 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
931 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
935 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
936 forward_step_scatter);
939 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
952 step_new(HsGatherDim) = 0;
958 if constexpr(BottomTensorView::buffer_view::get_address_space() ==
959 address_space_enum::global)
973 if constexpr(std::is_same_v<ValidArray, std::nullptr_t> ==
false)
994 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
998 if constexpr(
NDimP == 1)
1003 else if constexpr(
NDimP == 2)
1005 window_adaptor_thread_coord_tmp =
1018 window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
1020 bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
1027 using SFC_Ys =
typename Traits::SFC_Ys;
1030 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
1031 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
1033 constexpr
auto idx_diff_ys =
1040 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
1043 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
1072 address_space_enum::global,
1079 template <
typename TensorView_,
1080 typename WindowLengths_,
1081 typename StaticTileDistribution_,
1082 typename StaticPageIndexArray_,
1087 const WindowLengths_& window_lengths,
1088 const multi_index<TensorView_::get_num_of_dimension()>& origin,
1090 const StaticPageIndexArray_& page_idx,
1092 number<NumCoord> = {})
1094 return tile_scatter_gather<remove_cvref_t<TensorView_>,
1095 remove_cvref_t<WindowLengths_>,
1096 remove_cvref_t<StaticTileDistribution_>,
1097 remove_cvref_t<StaticPageIndexArray_>,
1101 tensor_view, window_lengths, origin, tile_distribution, page_idx,
nullptr};
1104 template <
typename TensorView,
1105 typename WindowLengths,
1106 typename StaticTileDistribution,
1107 typename StaticPageIndexArray,
1111 const multi_index<TensorView::get_num_of_dimension()>& origin,
1113 const StaticPageIndexArray& page_idx,
1121 number<HsGatherDim>{});
1124 template <
typename TensorView,
1125 typename WindowLengths,
1126 typename StaticTileDistribution,
1127 typename StaticPageIndexArray,
1132 const StaticPageIndexArray& page_idx,
1140 number<HsGatherDim>{});
1143 template <
typename TensorView_,
1144 typename WindowLengths_,
1145 typename StaticTileDistribution_,
1146 typename StaticPageIndexArray_,
1147 typename StaticValidArray_,
1152 const WindowLengths_& window_lengths,
1153 const multi_index<TensorView_::get_num_of_dimension()>& origin,
1155 const StaticPageIndexArray_& page_idx,
1156 const StaticValidArray_& valids,
1158 number<NumCoord> = {})
1160 return tile_scatter_gather<remove_cvref_t<TensorView_>,
1161 remove_cvref_t<WindowLengths_>,
1162 remove_cvref_t<StaticTileDistribution_>,
1163 remove_cvref_t<StaticPageIndexArray_>,
1164 remove_cvref_t<StaticValidArray_>,
1167 tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
1170 template <
typename TensorView,
1171 typename WindowLengths,
1172 typename StaticTileDistribution,
1173 typename StaticPageIndexArray,
1174 typename StaticValidArray,
1178 const multi_index<TensorView::get_num_of_dimension()>& origin,
1180 const StaticPageIndexArray& page_idx,
1181 const StaticValidArray& valids,
1190 number<HsGatherDim>{});
1193 template <
typename TensorView,
1194 typename WindowLengths,
1195 typename StaticTileDistribution,
1196 typename StaticPageIndexArray,
1197 typename StaticValidArray,
1202 const StaticPageIndexArray& page_idx,
1203 const StaticValidArray& valids,
1212 number<HsGatherDim>{});
1215 template <
typename NewTensorView_,
1216 typename OldTensorView_,
1217 typename WindowLengths_,
1218 typename StaticTileDistribution_,
1219 typename StaticPageIndexArray_,
1220 typename StaticValidArray_,
1226 StaticTileDistribution_,
1227 StaticPageIndexArray_,
1230 NumCoord>& tile_window)
1235 tile_window.tile_dstr_,
1236 tile_window.page_idx_,
1237 tile_window.valids_);
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_LDS_ADDR
Definition: config.hpp:62
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto to_array(const std::vector< X > &x)
Definition: array.hpp:286
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_ &new_tensor_view, const tile_scatter_gather< OldTensorView_, WindowLengths_, StaticTileDistribution_, StaticPageIndexArray_, StaticValidArray_, HsGatherDim, NumCoord > &tile_window)
Definition: tile_scatter_gather.hpp:1223
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:56
constant< b > bool_constant
Definition: integral_constant.hpp:43
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1126
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constant< v > number
Definition: integral_constant.hpp:37
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1066
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:15
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_DEVICE auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition: tile_scatter_gather.hpp:1086
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:98
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: sequence.hpp:298
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:313
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Definition: numeric.hpp:81
Definition: coordinate_transform.hpp:1392
Definition: space_filling_curve.hpp:20
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: tile_distribution.hpp:70
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:124
Definition: tile_scatter_gather.hpp:81
static constexpr index_t PackedSize
Definition: tile_scatter_gather.hpp:104
static constexpr index_t NumAccess
Definition: tile_scatter_gather.hpp:143
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_scatter_gather.hpp:141
static constexpr index_t VectorDimY
Definition: tile_scatter_gather.hpp:106
static constexpr index_t ScalarPerVector
Definition: tile_scatter_gather.hpp:107
This class provides tile (windowed) view and access to the device memory.
Definition: tile_scatter_gather.hpp:41
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_scatter_gather.hpp:948
static constexpr index_t NumAccessPerCoord
Definition: tile_scatter_gather.hpp:149
static constexpr auto I1
Definition: tile_scatter_gather.hpp:59
constexpr CK_TILE_DEVICE tile_scatter_gather(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution, const PageIdxArray &page_idx, const ValidArray &valids)
Definition: tile_scatter_gather.hpp:153
BottomTensorIndex window_origin_
Definition: tile_scatter_gather.hpp:1057
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:345
WindowLengths window_lengths_
Definition: tile_scatter_gather.hpp:1054
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_scatter_gather.hpp:270
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_scatter_gather.hpp:342
static constexpr index_t NDimBottomTensor
Definition: tile_scatter_gather.hpp:53
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_scatter_gather.hpp:303
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_scatter_gather.hpp:72
PageIdxArray page_idx_
Definition: tile_scatter_gather.hpp:1064
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:447
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_scatter_gather.hpp:43
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_scatter_gather.hpp:986
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_scatter_gather.hpp:1070
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_scatter_gather.hpp:274
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_scatter_gather.hpp:44
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_scatter_gather.hpp:285
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:357
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:855
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray &new_idx, const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:979
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_scatter_gather.hpp:48
TileDstr tile_dstr_
Definition: tile_scatter_gather.hpp:1062
ValidArray valids_
Definition: tile_scatter_gather.hpp:1065
static constexpr index_t NDimY
Definition: tile_scatter_gather.hpp:56
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_scatter_gather.hpp:50
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_scatter_gather.hpp:52
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_scatter_gather.hpp:263
remove_cvref_t< StaticValidArray_ > ValidArray
Definition: tile_scatter_gather.hpp:46
static constexpr index_t NDimP
Definition: tile_scatter_gather.hpp:55
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_scatter_gather.hpp:42
constexpr CK_TILE_DEVICE tile_scatter_gather()=default
CK_TILE_DEVICE void update(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:773
remove_cvref_t< StaticPageIndexArray_ > PageIdxArray
Definition: tile_scatter_gather.hpp:45
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_scatter_gather.hpp:268
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_scatter_gather.hpp:1047
static constexpr auto I0
Definition: tile_scatter_gather.hpp:58
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition: tile_scatter_gather.hpp:78
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_scatter_gather.hpp:272
std::conditional_t< BottomTensorView::buffer_view::get_address_space()==address_space_enum::global, array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord >, std::byte > pre_computed_warp_coords_
Definition: tile_scatter_gather.hpp:1075
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_scatter_gather.hpp:47
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_scatter_gather.hpp:75
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_scatter_gather.hpp:277
BottomTensorView bottom_tensor_view_
Definition: tile_scatter_gather.hpp:1051
CK_TILE_DEVICE void async_load_with_offset(index_t offset, LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< static_move_ys >={}) const
Definition: tile_scatter_gather.hpp:649
CK_TILE_DEVICE void update_valids(const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:971
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_scatter_gather.hpp:543
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_scatter_gather.hpp:71
CK_TILE_DEVICE void update_page_idx(const PageIdxArray &new_idx)
Definition: tile_scatter_gather.hpp:969
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_scatter_gather.hpp:261
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
BottomTensorIndex window_origin_
Definition: tile_window_base.hpp:79
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
WindowLengths window_lengths_
Definition: tile_window_base.hpp:81
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:1195
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10