32 template <
typename BottomTensorView_,
33 typename WindowLengths_,
34 typename StaticTileDistribution_,
35 typename StaticPageIndexArray_,
36 typename StaticValidArray_,
55 static constexpr
index_t NDimP = TileDstr::get_num_of_dimension_p();
56 static constexpr
index_t NDimY = TileDstr::get_num_of_dimension_y();
60 static_assert(NumCoord == 1);
65 "wrong! lengths should be static");
68 static_assert(
NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
69 "wrong! inconsistent # of diemsnions");
83 static constexpr
auto get_vector_dim_y_scalar_per_vector()
85 const auto [ys_vector_lengths, ys_vector_strides] =
93 if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
95 ScalarPerVector_ = ys_vector_lengths[i];
100 return make_tuple(VectorDimY_, ScalarPerVector_);
108 get_vector_dim_y_scalar_per_vector().template at<1>();
115 static constexpr
auto scalars_per_access_ = [] {
120 constexpr
auto NDimY_ =
NDimY;
122 return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
125 static constexpr
auto get_space_filling_curve()
127 constexpr
auto tile_dstr =
TileDstr{};
129 constexpr
auto thread_tensor_lengths_ys =
130 to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
137 decltype(scalars_per_access_)>{};
141 using SFC_Ys = decltype(get_space_filling_curve());
145 static_assert(0 <
NumAccess,
"Wrong! NumAccess should be larger than 0");
146 static_assert(
NumAccess % NumCoord == 0,
"wrong! # of access is not divisible by NumCoord");
171 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
175 if constexpr(
NDimP == 1)
180 else if constexpr(
NDimP == 2)
182 window_adaptor_thread_coord_tmp =
196 window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
197 bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
204 using SFC_Ys =
typename Traits::SFC_Ys;
207 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
208 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
210 constexpr
auto idx_diff_ys =
217 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
220 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
247 template <
typename ATopIndex>
251 const ATopIndex& idx_diff_adaptor_top)
const
256 window_adaptor_thread_coord,
257 idx_diff_adaptor_top,
258 idx_diff_adaptor_bottom);
261 bottom_tensor_thread_coord,
262 idx_diff_adaptor_bottom);
269 const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
270 BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
273 const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
274 const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
277 array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
279 array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
282 constexpr
auto window_adaptor_bottom_dims =
283 WindowAdaptor::get_bottom_dimension_hidden_ids();
286 window_adaptor_bottom_dims,
287 window_adaptor_bottom_dim_vector_lengths);
289 window_adaptor_bottom_dims,
290 window_adaptor_bottom_dim_vector_strides);
292 const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
293 WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
294 window_adaptor_vector_lengths, window_adaptor_vector_strides);
307 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
311 constexpr
auto tile_dstr =
TileDstr{};
312 auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
313 load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
317 template <
typename DistributedTensor,
318 index_t i_access_unsupport_ = -1,
319 bool oob_conditional_check =
true>
324 using Traits = load_store_traits;
325 using vector_t =
typename Traits::vector_t;
326 using SFC_Ys =
typename Traits::SFC_Ys;
328 constexpr
auto tile_dstr =
TileDstr{};
331 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
336 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
337 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
340 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
341 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
342 const auto page_offset =
page_idx_[idx_gather];
345 const vector_t vec_value = [&]() {
346 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
349 bottom_tensor_thread_coord,
351 bool_constant<oob_conditional_check>{});
356 bottom_tensor_thread_coord,
359 bool_constant<oob_conditional_check>{});
364 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
367 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
373 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
376 dst_tensor.get_thread_buffer().template at<d>() =
377 vec_value.template get_as<DataType>()[j / Traits::PackedSize];
381 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
382 static_assert(d % Traits::ScalarPerVector == 0);
384 dst_tensor.get_thread_buffer().template get_as<vector_t>()(
385 number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
390 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
393 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
397 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
398 forward_step_scatter);
401 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
408 template <
typename LdsTileWindow_,
409 index_t i_access_unsupport_ = -1,
410 bool oob_conditional_check =
true,
411 bool pre_nop =
false>
415 bool_constant<pre_nop> = {})
const
417 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
419 using LdsDataType =
typename LdsTileWindow::DataType;
423 static_assert(LdsTileWindow::get_num_of_dimension() == 3);
426 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
427 make_tuple(number<0>{}, number<0>{}, number<0>{})) *
431 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
432 make_tuple(number<0>{}, number<1>{}, number<0>{})) *
433 sizeof(LdsDataType) -
437 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
438 make_tuple(number<1>{}, number<0>{}, number<0>{})) *
439 sizeof(LdsDataType) -
442 const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
445 using Traits = load_store_traits;
448 using vector_t =
typename Traits::vector_t;
449 using SFC_Ys =
typename Traits::SFC_Ys;
451 LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
454 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
459 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
460 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
461 constexpr
auto pre_nop_ = [&]() {
462 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
465 return bool_constant<false>{};
468 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
469 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
470 const auto page_offset =
page_idx_[idx_gather];
473 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
476 smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
482 bottom_tensor_thread_coord,
492 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
495 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
499 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
500 forward_step_scatter);
503 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
511 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
516 using Traits = load_store_traits;
519 using vector_t =
typename Traits::vector_t;
520 using SFC_Ys =
typename Traits::SFC_Ys;
522 constexpr
auto tile_dstr =
TileDstr{};
525 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
529 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
530 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
533 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
534 constexpr
auto idx_gather = idx_ys_start[number<0>{}];
535 const auto page_offset =
page_idx_[idx_gather];
544 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
547 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
553 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
556 vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
563 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
566 bottom_tensor_thread_coord,
569 bool_constant<oob_conditional_check>{});
574 bottom_tensor_thread_coord,
578 bool_constant<oob_conditional_check>{});
585 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
588 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
592 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
593 forward_step_scatter);
596 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
609 step_new(HsGatherDim) = 0;
621 if constexpr(std::is_same_v<ValidArray, std::nullptr_t> ==
false)
642 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
646 if constexpr(
NDimP == 1)
651 else if constexpr(
NDimP == 2)
653 window_adaptor_thread_coord_tmp =
666 window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
668 bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
675 using SFC_Ys =
typename Traits::SFC_Ys;
678 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
679 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
681 constexpr
auto idx_diff_ys =
688 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
691 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
722 template <
typename TensorView_,
723 typename WindowLengths_,
724 typename StaticTileDistribution_,
725 typename StaticPageIndexArray_,
730 const WindowLengths_& window_lengths,
731 const multi_index<TensorView_::get_num_of_dimension()>& origin,
733 const StaticPageIndexArray_& page_idx,
735 number<NumCoord> = {})
737 return tile_scatter_gather<remove_cvref_t<TensorView_>,
738 remove_cvref_t<WindowLengths_>,
739 remove_cvref_t<StaticTileDistribution_>,
740 remove_cvref_t<StaticPageIndexArray_>,
744 tensor_view, window_lengths, origin, tile_distribution, page_idx,
nullptr};
747 template <
typename TensorView,
748 typename WindowLengths,
749 typename StaticTileDistribution,
750 typename StaticPageIndexArray,
754 const multi_index<TensorView::get_num_of_dimension()>& origin,
756 const StaticPageIndexArray& page_idx,
764 number<HsGatherDim>{});
767 template <
typename TensorView,
768 typename WindowLengths,
769 typename StaticTileDistribution,
770 typename StaticPageIndexArray,
775 const StaticPageIndexArray& page_idx,
783 number<HsGatherDim>{});
786 template <
typename TensorView_,
787 typename WindowLengths_,
788 typename StaticTileDistribution_,
789 typename StaticPageIndexArray_,
790 typename StaticValidArray_,
795 const WindowLengths_& window_lengths,
796 const multi_index<TensorView_::get_num_of_dimension()>& origin,
798 const StaticPageIndexArray_& page_idx,
799 const StaticValidArray_& valids,
801 number<NumCoord> = {})
803 return tile_scatter_gather<remove_cvref_t<TensorView_>,
804 remove_cvref_t<WindowLengths_>,
805 remove_cvref_t<StaticTileDistribution_>,
806 remove_cvref_t<StaticPageIndexArray_>,
807 remove_cvref_t<StaticValidArray_>,
810 tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
813 template <
typename TensorView,
814 typename WindowLengths,
815 typename StaticTileDistribution,
816 typename StaticPageIndexArray,
817 typename StaticValidArray,
821 const multi_index<TensorView::get_num_of_dimension()>& origin,
823 const StaticPageIndexArray& page_idx,
824 const StaticValidArray& valids,
833 number<HsGatherDim>{});
836 template <
typename TensorView,
837 typename WindowLengths,
838 typename StaticTileDistribution,
839 typename StaticPageIndexArray,
840 typename StaticValidArray,
845 const StaticPageIndexArray& page_idx,
846 const StaticValidArray& valids,
855 number<HsGatherDim>{});
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constant< b > bool_constant
Definition: integral_constant.hpp:43
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1112
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constant< v > number
Definition: integral_constant.hpp:37
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1052
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:15
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_DEVICE auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition: tile_scatter_gather.hpp:729
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:97
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
Definition: sequence.hpp:284
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:299
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Definition: numeric.hpp:81
Definition: space_filling_curve.hpp:20
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
Definition: tile_scatter_gather.hpp:81
static constexpr index_t PackedSize
Definition: tile_scatter_gather.hpp:104
static constexpr index_t NumAccess
Definition: tile_scatter_gather.hpp:143
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_scatter_gather.hpp:141
static constexpr index_t VectorDimY
Definition: tile_scatter_gather.hpp:106
static constexpr index_t ScalarPerVector
Definition: tile_scatter_gather.hpp:107
This class provides a tile (windowed) view of, and access to, the device memory.
Definition: tile_scatter_gather.hpp:41
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_scatter_gather.hpp:605
static constexpr index_t NumAccessPerCoord
Definition: tile_scatter_gather.hpp:149
static constexpr auto I1
Definition: tile_scatter_gather.hpp:59
constexpr CK_TILE_DEVICE tile_scatter_gather(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution, const PageIdxArray &page_idx, const ValidArray &valids)
Definition: tile_scatter_gather.hpp:153
BottomTensorIndex window_origin_
Definition: tile_scatter_gather.hpp:705
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:308
WindowLengths window_lengths_
Definition: tile_scatter_gather.hpp:702
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_scatter_gather.hpp:233
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_scatter_gather.hpp:305
static constexpr index_t NDimBottomTensor
Definition: tile_scatter_gather.hpp:53
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_scatter_gather.hpp:266
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_scatter_gather.hpp:72
PageIdxArray page_idx_
Definition: tile_scatter_gather.hpp:712
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_scatter_gather.hpp:43
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_scatter_gather.hpp:634
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_scatter_gather.hpp:718
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_scatter_gather.hpp:237
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_scatter_gather.hpp:44
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_scatter_gather.hpp:248
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:320
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:512
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray &new_idx, const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:627
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_scatter_gather.hpp:48
TileDstr tile_dstr_
Definition: tile_scatter_gather.hpp:710
ValidArray valids_
Definition: tile_scatter_gather.hpp:713
static constexpr index_t NDimY
Definition: tile_scatter_gather.hpp:56
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_scatter_gather.hpp:50
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_scatter_gather.hpp:52
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_scatter_gather.hpp:226
remove_cvref_t< StaticValidArray_ > ValidArray
Definition: tile_scatter_gather.hpp:46
static constexpr index_t NDimP
Definition: tile_scatter_gather.hpp:55
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_scatter_gather.hpp:42
constexpr CK_TILE_DEVICE tile_scatter_gather()=default
remove_cvref_t< StaticPageIndexArray_ > PageIdxArray
Definition: tile_scatter_gather.hpp:45
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_scatter_gather.hpp:231
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_scatter_gather.hpp:695
static constexpr auto I0
Definition: tile_scatter_gather.hpp:58
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition: tile_scatter_gather.hpp:78
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_scatter_gather.hpp:235
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_scatter_gather.hpp:47
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_scatter_gather.hpp:75
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_scatter_gather.hpp:240
BottomTensorView bottom_tensor_view_
Definition: tile_scatter_gather.hpp:699
CK_TILE_DEVICE void update_valids(const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:619
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_scatter_gather.hpp:412
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_scatter_gather.hpp:71
CK_TILE_DEVICE void update_page_idx(const PageIdxArray &new_idx)
Definition: tile_scatter_gather.hpp:617
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_scatter_gather.hpp:224
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
This class provides a description of a tile windowed view on the device memory.
Definition: tile_window.hpp:873
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10