33 template <
typename BottomTensorView_,
34 typename WindowLengths_,
35 typename StaticTileDistribution_,
39 tile_window_with_static_distribution<BottomTensorView_,
41 StaticTileDistribution_,
45 StaticTileDistribution_>
50 StaticTileDistribution_,
54 StaticTileDistribution_>;
58 static_assert(NumCoord == 1);
60 static_assert(Base::Traits::NumAccess % NumCoord == 0,
61 "wrong! # of access is not divisible by NumCoord");
84 window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
87 bottom_tensor_view.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
91 using Traits =
typename Base::Traits;
92 using SFC_Ys =
typename Traits::SFC_Ys;
95 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
96 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
98 constexpr
auto idx_diff_ys =
106 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
109 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
113 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
118 auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
119 load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
133 template <
typename TileWindow_,
134 typename ElementWise_,
135 index_t i_access_unsupport_ = -1,
136 bool oob_conditional_check =
true>
138 ElementWise_ elementwise,
143 auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
147 number<i_access_unsupport_>{},
148 bool_constant<oob_conditional_check>{});
152 template <
typename DistributedTensor,
153 typename TileWindow_,
154 typename ElementWise_,
155 index_t i_access_unsupport_ = -1,
156 bool oob_conditional_check =
true>
158 const TileWindow_& tile_window,
159 ElementWise_ elementwise,
164 using Traits =
typename Base::Traits;
165 using vector_t =
typename Traits::vector_t;
166 using SFC_Ys =
typename Traits::SFC_Ys;
169 constexpr
auto sizeOfTuple = TileWindow_::size();
171 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
173 auto window_adaptor_thread_coord =
174 tile_window[number<0>{}].pre_computed_coords_[iCoord][
I0];
175 auto bottom_tensor_thread_coord =
176 tile_window[number<0>{}].pre_computed_coords_[iCoord][
I1];
178 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
179 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
182 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
187 return tile_window[number<jj>{}]
188 .get_bottom_tensor_view()
189 .
template get_vectorized_elements<vector_t>(
190 bottom_tensor_thread_coord,
192 bool_constant<oob_conditional_check>{});
194 number<sizeOfTuple>{});
197 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
200 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
203 number<Base::NDimY>{});
206 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
211 elementwise(dst_tensor.get_thread_buffer().template at<d>(),
220 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
223 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
227 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
233 template <
typename DistributedTensor,
234 index_t i_access_unsupport_ = -1,
235 bool oob_conditional_check =
true>
240 using Traits =
typename Base::Traits;
241 using vector_t =
typename Traits::vector_t;
242 using SFC_Ys =
typename Traits::SFC_Ys;
247 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
252 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
253 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
256 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
259 const vector_t vec_value =
261 bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
263 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
266 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
269 number<Base::NDimY>{});
272 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
275 dst_tensor.get_thread_buffer().template at<d>() =
277 .template get_as<typename Base::DataType>()[j / Traits::PackedSize];
282 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
285 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
289 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
295 template <
typename DstTile,
296 index_t i_access_unsupport_ = -1,
297 bool oob_conditional_check =
true,
298 bool pre_nop =
false>
302 bool_constant<pre_nop> = {})
const
304 using Traits =
typename Base::Traits;
305 using vector_t =
typename Traits::vector_t;
306 using SFC_Ys =
typename Traits::SFC_Ys;
307 static constexpr
index_t YElementSize =
308 typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
309 static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
310 using vectorized_tbuf =
311 array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
315 auto& dst_vec_tbuf =
reinterpret_cast<vectorized_tbuf&
>(dst_tensor.get_thread_buffer());
318 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
323 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
324 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
325 constexpr
auto pre_nop_ = [&]() {
326 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
329 return bool_constant<false>{};
333 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
335 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
337 static_assert(d % Traits::ScalarPerVector == 0);
340 dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
341 bottom_tensor_thread_coord,
343 bool_constant<oob_conditional_check>{},
345 #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
346 CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
353 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
356 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
360 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
367 template <
typename LdsTileWindow_,
368 index_t i_access_unsupport_ = -1,
369 bool oob_conditional_check =
true,
370 bool pre_nop =
false>
374 bool_constant<pre_nop> = {})
const
376 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
378 using LdsDataType =
typename LdsTileWindow::DataType;
382 static_assert(LdsTileWindow::get_num_of_dimension() == 3);
385 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
386 make_tuple(number<0>{}, number<0>{}, number<0>{})) *
390 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
391 make_tuple(number<0>{}, number<1>{}, number<0>{})) *
392 sizeof(LdsDataType) -
396 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
397 make_tuple(number<1>{}, number<0>{}, number<0>{})) *
398 sizeof(LdsDataType) -
403 size_per_buf + size_per_wave * get_warp_id(bool_constant<false>{});
407 using Traits =
typename Base::Traits;
409 using vector_t =
typename Traits::vector_t;
410 using SFC_Ys =
typename Traits::SFC_Ys;
412 LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
415 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
420 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
421 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
422 constexpr
auto pre_nop_ = [&]() {
423 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
426 return bool_constant<false>{};
431 smem, bottom_tensor_thread_coord, 0, pre_nop_);
436 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
439 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
443 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
451 template <
typename LdsTileWindow_,
452 index_t i_access_unsupport_ = -1,
453 bool oob_conditional_check =
true>
458 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
459 using LdsDataType =
typename LdsTileWindow::DataType;
460 using Traits =
typename Base::Traits;
462 using vector_t =
typename Traits::vector_t;
463 using SFC_Ys =
typename Traits::SFC_Ys;
466 const auto window_origin = lds_tile.get_window_origin();
467 const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
468 const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
469 auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
471 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
475 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
476 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
479 auto lds_bottom_tensor_thread_idx =
480 window_origin + window_adaptor_thread_coord.get_bottom_index();
483 const auto lds_coord =
487 CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
492 bottom_tensor_thread_coord,
494 bool_constant<oob_conditional_check>{});
499 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
501 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
505 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
511 template <
typename Policy,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
515 auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
516 this->
template load_transpose<Policy>(
521 template <
typename Policy,
522 typename DistributedTensor,
523 index_t i_access_unsupport_ = -1,
524 bool oob_conditional_check =
true>
529 using Traits =
typename Base::Traits;
530 using vector_t =
typename Traits::vector_t;
531 using SFC_Ys =
typename Traits::SFC_Ys;
535 constexpr
auto group_func = Policy::group_func;
538 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
543 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
544 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
547 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
550 const vector_t vec_value =
552 .template get_transpose_vectorized_elements<vector_t>(
553 bottom_tensor_thread_coord, 0);
555 static_for<0, Traits::ScalarPerVector, 1>{}([&](
auto j) {
558 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
561 number<Base::NDimY>{});
563 constexpr
auto grouped_idx_ys = group_func(orig_idx_ys);
565 constexpr
index_t linear_distributed_index =
566 tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
568 dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
569 vec_value.template get_as<typename Base::DataType>()[j];
574 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
577 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
581 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
587 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
593 using Traits =
typename Base::Traits;
595 using vector_t =
typename Traits::vector_t;
596 using SFC_Ys =
typename Traits::SFC_Ys;
601 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
605 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
606 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
609 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
615 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
618 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
621 number<Base::NDimY>{});
624 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
627 vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
628 dstr_tensor.get_thread_buffer().template at<d>();
635 bottom_tensor_thread_coord,
638 bool_constant<oob_conditional_check>{});
643 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
646 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
650 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
656 template <
index_t i_access_unsupport_ = -1>
662 using Traits =
typename Base::Traits;
664 using vector_t =
typename Traits::vector_t;
665 using SFC_Ys =
typename Traits::SFC_Ys;
668 static constexpr
bool oob_conditional_check =
true;
671 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
676 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
677 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
680 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
684 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
687 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
690 number<Base::NDimY>{});
692 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
694 vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
700 .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
701 bottom_tensor_thread_coord, 0, vec_value);
706 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
709 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
713 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
719 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
726 using Traits =
typename Base::Traits;
728 using vector_t =
typename Traits::vector_t;
729 using SFC_Ys =
typename Traits::SFC_Ys;
734 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
739 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
740 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
743 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
748 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
751 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
754 number<Base::NDimY>{});
757 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
760 vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
766 bottom_tensor_thread_coord,
769 bool_constant<oob_conditional_check>{});
774 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
777 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
781 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
787 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true,
bool pre_nop>
793 bool_constant<pre_nop> = {})
const
795 using Traits =
typename Base::Traits;
797 using vector_t =
typename Traits::vector_t;
798 using SFC_Ys =
typename Traits::SFC_Ys;
803 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
808 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
809 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
812 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
817 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
820 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
823 number<Base::NDimY>{});
826 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
829 vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
835 bottom_tensor_thread_coord,
838 bool_constant<oob_conditional_check>{},
839 bool_constant<pre_nop>{});
844 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
847 generate_tuple([&](
auto) {
return number<0>{}; }, number<Base::NDimP>{}),
851 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
877 this->
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
884 using Traits =
typename Base::Traits;
885 using SFC_Ys =
typename Traits::SFC_Ys;
888 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
889 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
891 constexpr
auto idx_diff_ys =
899 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
902 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
914 template <
typename TensorView_,
915 typename WindowLengths_,
916 typename StaticTileDistribution_,
920 const WindowLengths_& window_lengths,
921 const multi_index<TensorView_::get_num_of_dimension()>& origin,
925 return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
926 remove_cvref_t<WindowLengths_>,
927 remove_cvref_t<StaticTileDistribution_>,
929 tensor_view, window_lengths, origin, tile_distribution};
933 template <
typename TensorView_,
934 typename WindowLengths_,
935 typename StaticTileDistribution_,
939 const WindowLengths_& window_lengths,
940 const multi_index<TensorView_::get_num_of_dimension()>& origin,
944 auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
945 remove_cvref_t<WindowLengths_>,
946 remove_cvref_t<StaticTileDistribution_>,
948 tensor_view, window_lengths, origin, tile_distribution};
953 template <
typename TensorView_,
954 typename WindowLengths_,
955 typename StaticTileDistribution_,
960 StaticTileDistribution_,
964 StaticTileDistribution_,
965 NumCoord>::BottomTensorIndex& step)
970 template <
typename TensorView_,
971 typename WindowLengths_,
972 typename StaticTileDistribution_,
977 StaticTileDistribution_,
981 StaticTileDistribution_,
982 NumCoord>::BottomTensorIndex& step)
986 StaticTileDistribution_,
989 static constexpr
auto N = T::size();
993 template <
typename TileWindowWithStaticDistributionType,
999 static constexpr
auto N = TileWindowWithStaticDistributionType::size();
1011 template <
typename BottomTensorView_,
typename WindowLengths_>
1013 :
public tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
1047 template <
typename DataType>
1052 const char* label =
"")
const
1057 printf(
"%s Window Range [%d:%d, %d:%d] (origin: %d, %d):\n",
1066 for(
index_t i = start_i; i < end_i; i++)
1068 for(
index_t j = start_j; j < end_j; j++)
1073 make_tuple(window_origin[0] + i, window_origin[1] + j));
1077 auto buf =
tensor_view.template get_vectorized_elements<ThreadBuf>(coord, 0);
1079 printf(
" %s[%d,%d] = %f", label, i, j,
static_cast<float>(
value));
1087 template <
typename TensorView_,
typename WindowLengths_>
1090 const WindowLengths_& window_lengths,
1091 const multi_index<TensorView_::get_num_of_dimension()>& origin)
1094 "wrong! lengths should be static");
1102 template <
typename TensorView,
typename WindowLengths>
1105 const multi_index<TensorView::get_num_of_dimension()>& origin)
1111 template <
typename TensorView,
typename WindowLengths,
typename StaticTileDistribution>
1114 const multi_index<TensorView::get_num_of_dimension()>& origin,
1123 template <
typename TensorView,
typename WindowLengths,
typename StaticTileDistribution>
1134 template <
typename TensorView,
typename WindowLengths,
typename StaticTileDistribution>
1147 template <
typename TensorView_,
typename WindowLengths_>
1163 template <
typename T>
1176 template <
typename BottomTensorView_,
1177 typename WindowLengths_,
1178 typename StaticTileDistribution_,
1183 StaticTileDistribution_,
1195 template <
typename T>
1206 template <
typename T>
1217 template <
typename BottomTensorView_,
typename WindowLengths_>
1230 template <
typename T>
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr decltype(auto) apply(F &&f, Tuple &&t)
Definition: tuple.hpp:526
constexpr bool is_tile_window_with_static_distribution_v
Helper variable template to check if a type is a tile window with static distribution.
Definition: tile_window.hpp:1196
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:2834
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto make_tile_window_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, number< NumCoord >={})
Definition: tile_window.hpp:938
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constant< b > bool_constant
Definition: integral_constant.hpp:43
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr bool is_tile_window_with_static_lengths_v
Helper variable template to check if a type is a tile window with static lengths.
Definition: tile_window.hpp:1231
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
bool_constant< false > false_type
Definition: integral_constant.hpp:63
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
bool_constant< true > true_type
Definition: integral_constant.hpp:62
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Type trait to determine if a type is a tile window with static distribution.
Definition: tile_window.hpp:1165
Type trait to determine if a type is a tile window with static lengths.
Definition: tile_window.hpp:1208
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:81
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
constexpr CK_TILE_HOST_DEVICE auto & get_tensor_descriptor() const
Definition: tensor_view.hpp:61
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
This class provides description of tile windowed view on the device memory.
Definition: tile_window_base.hpp:31
BottomTensorView bottom_tensor_view_
Definition: tile_window_base.hpp:85
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_window_base.hpp:36
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
BottomTensorIndex window_origin_
Definition: tile_window_base.hpp:79
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_window_base.hpp:67
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_window_base.hpp:33
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_window_base.hpp:34
WindowLengths window_lengths_
Definition: tile_window_base.hpp:81
This class provides tile (windowed) view and access to the device memory.
Definition: tile_window.hpp:46
CK_TILE_DEVICE void store_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}) const
Definition: tile_window.hpp:658
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex &step)
Definition: tile_window.hpp:858
CK_TILE_DEVICE auto load_transpose() const
Definition: tile_window.hpp:512
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex &)
Definition: tile_window.hpp:867
array< tuple< typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_window.hpp:910
CK_TILE_DEVICE void update_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:789
constexpr CK_TILE_DEVICE tile_window_with_static_distribution()=default
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:454
CK_TILE_DEVICE auto load_transpose(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:525
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:157
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:236
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:114
static constexpr auto I0
Definition: tile_window.hpp:56
CK_TILE_DEVICE auto load(const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Load tile with elementwise function.
Definition: tile_window.hpp:137
CK_TILE_DEVICE void load_raw(DstTile &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:299
static constexpr auto I1
Definition: tile_window.hpp:57
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:371
CK_TILE_DEVICE void update(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:721
constexpr CK_TILE_DEVICE tile_window_with_static_distribution(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin, const typename Base::TileDstr &tile_distribution)
Definition: tile_window.hpp:66
static constexpr index_t NumAccessPerCoord
Definition: tile_window.hpp:62
CK_TILE_DEVICE void store(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:588
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:1016
constexpr CK_TILE_DEVICE tile_window_with_static_lengths()=default
CK_TILE_DEVICE void print_tile_window_range(index_t start_i, index_t end_i, index_t start_j, index_t end_j, const char *label="") const
Definition: tile_window.hpp:1048
constexpr CK_TILE_DEVICE tile_window_with_static_lengths(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin)
Definition: tile_window.hpp:1024
Definition: tile_window_base.hpp:94
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_window_base.hpp:95
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_window_base.hpp:129
TileDstr tile_dstr_
Definition: tile_window_base.hpp:253
Definition: tuple.hpp:192