32 template <
typename BottomTensorView_,
33 typename WindowLengths_,
34 typename StaticTileDistribution_,
50 static constexpr
index_t NDimP = TileDstr::get_num_of_dimension_p();
51 static constexpr
index_t NDimY = TileDstr::get_num_of_dimension_y();
55 static_assert(NumCoord == 1);
60 "wrong! lengths should be static");
63 static_assert(
NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
64 "wrong! inconsistent # of diemsnions");
78 static constexpr
auto get_vector_dim_y_scalar_per_vector()
80 const auto [ys_vector_lengths, ys_vector_strides] =
89 if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
91 ScalarPerVector_ = ys_vector_lengths[i];
96 return make_tuple(VectorDimY_, ScalarPerVector_);
102 get_vector_dim_y_scalar_per_vector().template at<1>();
109 static constexpr
auto scalars_per_access_ = [] {
114 constexpr
auto NDimY_ =
NDimY;
116 return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
119 static constexpr
auto get_space_filling_curve()
121 constexpr
auto tile_dstr =
TileDstr{};
123 constexpr
auto thread_tensor_lengths_ys =
124 to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
131 decltype(scalars_per_access_)>{};
135 using SFC_Ys = decltype(get_space_filling_curve());
139 static_assert(0 <
NumAccess,
"Wrong! NumAccess should be larger than 0");
140 static_assert(
NumAccess % NumCoord == 0,
"wrong! # of access is not divisible by NumCoord");
162 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
166 if constexpr(
NDimP == 1)
171 else if constexpr(
NDimP == 2)
173 window_adaptor_thread_coord_tmp =
187 window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
195 using SFC_Ys =
typename Traits::SFC_Ys;
198 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
199 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
201 constexpr
auto idx_diff_ys =
208 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
211 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
238 template <
typename ATopIndex>
242 const ATopIndex& idx_diff_adaptor_top)
const
247 window_adaptor_thread_coord,
248 idx_diff_adaptor_top,
249 idx_diff_adaptor_bottom);
252 bottom_tensor_thread_coord,
253 idx_diff_adaptor_bottom);
260 const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
261 BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
264 const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
265 const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
268 array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
270 array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
273 constexpr
auto window_adaptor_bottom_dims =
274 WindowAdaptor::get_bottom_dimension_hidden_ids();
277 window_adaptor_bottom_dims,
278 window_adaptor_bottom_dim_vector_lengths);
280 window_adaptor_bottom_dims,
281 window_adaptor_bottom_dim_vector_strides);
283 const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
284 WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
285 window_adaptor_vector_lengths, window_adaptor_vector_strides);
298 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
302 constexpr
auto tile_dstr =
TileDstr{};
303 auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
304 load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
308 template <
typename DistributedTensor,
309 index_t i_access_unsupport_ = -1,
310 bool oob_conditional_check =
true>
315 using Traits = load_store_traits;
316 using vector_t =
typename Traits::vector_t;
317 using SFC_Ys =
typename Traits::SFC_Ys;
319 constexpr
auto tile_dstr =
TileDstr{};
322 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
327 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
328 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
331 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
334 const vector_t vec_value =
336 bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
339 static_for<0, Traits::ScalarPerVector, 1>{}([&](
auto j) {
342 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
348 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
350 dst_tensor.get_thread_buffer().template at<d>() =
351 vec_value.template get_as<DataType>()[j];
355 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
356 static_assert(d % Traits::ScalarPerVector == 0);
358 dst_tensor.get_thread_buffer().template get_as<vector_t>()(
359 number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
364 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
367 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
371 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
377 template <
typename DstTile,
378 index_t i_access_unsupport_ = -1,
379 bool oob_conditional_check =
true,
380 bool pre_nop =
false>
384 bool_constant<pre_nop> = {})
const
386 using Traits = load_store_traits;
389 using vector_t =
typename Traits::vector_t;
390 using SFC_Ys =
typename Traits::SFC_Ys;
391 static constexpr
index_t YElementSize =
392 TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
393 static_assert(YElementSize % Traits::ScalarPerVector == 0);
394 using vectorized_tbuf = array<vector_t, YElementSize / Traits::ScalarPerVector>;
400 constexpr
auto tile_dstr =
TileDstr{};
402 auto& dst_vec_tbuf =
reinterpret_cast<vectorized_tbuf&
>(dst_tensor.get_thread_buffer());
405 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
410 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
411 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
412 constexpr
auto pre_nop_ = [&]() {
413 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
416 return bool_constant<false>{};
420 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
422 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
423 static_assert(d % Traits::ScalarPerVector == 0);
426 dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
427 bottom_tensor_thread_coord,
429 bool_constant<oob_conditional_check>{},
431 #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
432 CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
439 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
442 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
446 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
453 template <
typename LdsTileWindow_,
454 index_t i_access_unsupport_ = -1,
455 bool oob_conditional_check =
true,
456 bool pre_nop =
false>
460 bool_constant<pre_nop> = {})
const
462 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
464 using LdsDataType =
typename LdsTileWindow::DataType;
468 static_assert(LdsTileWindow::get_num_of_dimension() == 3);
471 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
472 make_tuple(number<0>{}, number<0>{}, number<0>{})) *
476 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
477 make_tuple(number<0>{}, number<1>{}, number<0>{})) *
478 sizeof(LdsDataType) -
482 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
483 make_tuple(number<1>{}, number<0>{}, number<0>{})) *
484 sizeof(LdsDataType) -
490 using Traits = load_store_traits;
493 using vector_t =
typename Traits::vector_t;
494 using SFC_Ys =
typename Traits::SFC_Ys;
496 LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
499 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
504 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
505 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
506 constexpr
auto pre_nop_ = [&]() {
507 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
510 return bool_constant<false>{};
515 smem, bottom_tensor_thread_coord, 0, pre_nop_);
520 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
523 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
527 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
535 template <
typename LdsTileWindow_,
536 index_t i_access_unsupport_ = -1,
537 bool oob_conditional_check =
true>
542 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
543 using LdsDataType =
typename LdsTileWindow::DataType;
546 static_assert(LdsTileWindow::get_num_of_dimension() == 3);
551 constexpr
index_t size_per_buf =
552 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
553 make_tuple(number<0>{}, number<0>{}, number<0>{}));
555 constexpr
index_t size_per_wave =
556 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
557 make_tuple(number<0>{}, number<1>{}, number<0>{})) -
560 constexpr
index_t size_per_issue =
561 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
562 make_tuple(number<1>{}, number<0>{}, number<0>{})) -
567 using Traits = load_store_traits;
569 using vector_t =
typename Traits::vector_t;
570 using SFC_Ys =
typename Traits::SFC_Ys;
574 lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
577 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
582 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
583 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
587 smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
592 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
595 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
599 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
601 smem += size_per_issue;
607 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
612 using Traits = load_store_traits;
615 using vector_t =
typename Traits::vector_t;
616 using SFC_Ys =
typename Traits::SFC_Ys;
618 constexpr
auto tile_dstr =
TileDstr{};
621 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
625 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
626 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
629 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
635 static_for<0, Traits::ScalarPerVector, 1>{}([&](
auto j) {
638 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
644 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
646 vec_value.template get_as<DataType>()(j) =
654 bottom_tensor_thread_coord,
657 bool_constant<oob_conditional_check>{});
662 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
665 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
669 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
675 template <
index_t i_access_unsupport_ = -1>
679 using Traits = load_store_traits;
681 using vector_t =
typename Traits::vector_t;
682 using SFC_Ys =
typename Traits::SFC_Ys;
684 constexpr
auto tile_dstr =
TileDstr{};
685 static constexpr
bool oob_conditional_check =
true;
688 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
693 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
694 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
697 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
701 static_for<0, Traits::ScalarPerVector, 1>{}([&](
auto j) {
704 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
709 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
710 vec_value.template get_as<DataType>()(j) =
716 .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
717 bottom_tensor_thread_coord, 0, vec_value);
722 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
725 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
729 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
735 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
740 using Traits = load_store_traits;
742 using vector_t =
typename Traits::vector_t;
743 using SFC_Ys =
typename Traits::SFC_Ys;
745 constexpr
auto tile_dstr =
TileDstr{};
748 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
753 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
754 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
757 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
762 static_for<0, Traits::ScalarPerVector, 1>{}([&](
auto j) {
765 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
771 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
773 vec_value.template get_as<DataType>()(j) =
779 bottom_tensor_thread_coord,
782 bool_constant<oob_conditional_check>{});
787 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
790 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
794 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
800 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true,
bool pre_nop>
804 bool_constant<pre_nop> = {})
const
806 using Traits = load_store_traits;
808 using vector_t =
typename Traits::vector_t;
809 using SFC_Ys =
typename Traits::SFC_Ys;
811 constexpr
auto tile_dstr =
TileDstr{};
814 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
819 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
820 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
823 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
828 static_for<0, Traits::ScalarPerVector, 1>{}([&](
auto j) {
831 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
837 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
839 vec_value.template get_as<DataType>()(j) =
845 bottom_tensor_thread_coord,
848 bool_constant<oob_conditional_check>{},
849 bool_constant<pre_nop>{});
854 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
857 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
861 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
889 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
893 if constexpr(
NDimP == 1)
898 else if constexpr(
NDimP == 2)
900 window_adaptor_thread_coord_tmp =
913 window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
921 using SFC_Ys =
typename Traits::SFC_Ys;
924 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
925 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
927 constexpr
auto idx_diff_ys =
934 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
937 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
965 template <
typename TensorView_,
966 typename WindowLengths_,
967 typename StaticTileDistribution_,
971 const WindowLengths_& window_lengths,
972 const multi_index<TensorView_::get_num_of_dimension()>& origin,
976 return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
977 remove_cvref_t<WindowLengths_>,
978 remove_cvref_t<StaticTileDistribution_>,
980 tensor_view, window_lengths, origin, tile_distribution};
984 template <
typename TensorView_,
985 typename WindowLengths_,
986 typename StaticTileDistribution_,
990 const WindowLengths_& window_lengths,
991 const multi_index<TensorView_::get_num_of_dimension()>& origin,
995 auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
996 remove_cvref_t<WindowLengths_>,
997 remove_cvref_t<StaticTileDistribution_>,
999 tensor_view, window_lengths, origin, tile_distribution};
1004 template <
typename TensorView_,
1005 typename WindowLengths_,
1006 typename StaticTileDistribution_,
1011 StaticTileDistribution_,
1015 StaticTileDistribution_,
1016 NumCoord>::BottomTensorIndex& step)
1029 template <
typename BottomTensorView_,
typename WindowLengths_>
1040 "wrong! lengths should be static");
1089 template <
typename TensorView_,
typename WindowLengths_>
1092 const WindowLengths_& window_lengths,
1093 const multi_index<TensorView_::get_num_of_dimension()>& origin)
1096 "wrong! lengths should be static");
1104 template <
typename TensorView,
typename WindowLengths>
1107 const multi_index<TensorView::get_num_of_dimension()>& origin)
1113 template <
typename TensorView,
typename WindowLengths,
typename StaticTileDistribution>
1116 const multi_index<TensorView::get_num_of_dimension()>& origin,
1125 template <
typename TensorView,
typename WindowLengths,
typename StaticTileDistribution>
1136 template <
typename TensorView,
typename WindowLengths,
typename StaticTileDistribution>
1149 template <
typename TensorView_,
typename WindowLengths_>
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_LDS_ADDR
Definition: config.hpp:56
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
tuple_array< T, N > thread_buffer
Definition: thread_buffer.hpp:14
CK_TILE_DEVICE auto make_tile_window_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, number< NumCoord >={})
Definition: tile_window.hpp:989
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constant< b > bool_constant
Definition: integral_constant.hpp:39
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1106
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constant< v > number
Definition: integral_constant.hpp:33
CK_TILE_DEVICE index_t get_warp_id()
Definition: arch.hpp:71
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1046
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:400
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:14
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:97
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:86
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
Definition: sequence.hpp:278
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:293
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:75
Definition: space_filling_curve.hpp:20
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:56
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
Definition: tile_window.hpp:76
static constexpr index_t NumAccess
Definition: tile_window.hpp:137
static constexpr index_t VectorDimY
Definition: tile_window.hpp:100
thread_buffer< DataType, ScalarPerVector > vector_t
Definition: tile_window.hpp:106
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_window.hpp:135
static constexpr index_t ScalarPerVector
Definition: tile_window.hpp:101
This class provides tile (windowed) view and access to the device memory.
Definition: tile_window.hpp:37
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_window.hpp:217
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition: tile_window.hpp:73
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_window.hpp:47
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_window.hpp:67
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_window.hpp:215
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_window.hpp:224
static constexpr index_t NDimY
Definition: tile_window.hpp:51
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_window.hpp:39
CK_TILE_DEVICE void update_raw(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:801
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_window.hpp:870
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window.hpp:226
constexpr CK_TILE_DEVICE tile_window_with_static_distribution()=default
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_window.hpp:296
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window.hpp:228
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:538
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_window.hpp:257
CK_TILE_DEVICE void store_raw(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}) const
Definition: tile_window.hpp:676
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:311
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:608
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:299
static constexpr index_t NDimBottomTensor
Definition: tile_window.hpp:48
BottomTensorIndex window_origin_
Definition: tile_window.hpp:951
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_window.hpp:941
CK_TILE_DEVICE void update(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:736
static constexpr auto I0
Definition: tile_window.hpp:53
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window.hpp:222
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_window.hpp:961
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_window.hpp:40
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_window.hpp:45
CK_TILE_DEVICE void load_raw(DstTile &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:381
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_window.hpp:38
static constexpr auto I1
Definition: tile_window.hpp:54
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_window.hpp:42
WindowLengths window_lengths_
Definition: tile_window.hpp:948
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_window.hpp:43
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:457
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_window.hpp:881
TileDstr tile_dstr_
Definition: tile_window.hpp:956
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_window.hpp:231
BottomTensorView bottom_tensor_view_
Definition: tile_window.hpp:945
constexpr CK_TILE_DEVICE tile_window_with_static_distribution(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution)
Definition: tile_window.hpp:147
static constexpr index_t NumAccessPerCoord
Definition: tile_window.hpp:143
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_window.hpp:239
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_window.hpp:70
static constexpr index_t NDimP
Definition: tile_window.hpp:50
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_window.hpp:66
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:1031
static constexpr index_t NDimBottomTensor
Definition: tile_window.hpp:1037
BottomTensorView bottom_tensor_view_
Definition: tile_window.hpp:1080
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_window.hpp:1056
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window.hpp:1058
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window.hpp:1060
constexpr CK_TILE_DEVICE tile_window_with_static_lengths(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin)
Definition: tile_window.hpp:1046
constexpr CK_TILE_DEVICE tile_window_with_static_lengths()=default
typename BottomTensorView::DataType DataType
Definition: tile_window.hpp:1035
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_window.hpp:1076
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_window.hpp:1034
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window.hpp:1062
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_window.hpp:1032
BottomTensorIndex window_origin_
Definition: tile_window.hpp:1086
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_window.hpp:1064
WindowLengths window_lengths_
Definition: tile_window.hpp:1083
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_window.hpp:1033
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_window.hpp:1070
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10