/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/slice_tile.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/slice_tile.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/slice_tile.hpp Source File
slice_tile.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
15 
16 namespace ck_tile {
17 
18 template <typename BottomTensorView_,
19  typename WindowLengths_,
20  index_t... SliceBegins,
21  index_t... SliceEnds>
22 CK_TILE_DEVICE constexpr auto
24  sequence<SliceBegins...> slice_begins,
25  sequence<SliceEnds...> slice_ends)
26 {
28  // NOTE: This API will override the origin of the tile window!
29  static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds));
30  static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension());
31 
32  constexpr auto slice_lengths = slice_ends - slice_begins;
33 
35  sequence_to_tuple_of_number(slice_lengths),
36  to_multi_index(slice_begins));
37 }
38 
39 template <typename DataType_,
40  typename StaticTileDistribution_,
41  index_t... SliceBegins,
42  index_t... SliceEnds>
43 CK_TILE_DEVICE constexpr auto
45  sequence<SliceBegins...> slice_begins,
46  sequence<SliceEnds...> slice_ends)
47 {
48  using DataType = remove_cvref_t<DataType_>;
49  using Distribution = remove_cvref_t<StaticTileDistribution_>;
50 
51  constexpr auto sliced_dstr_yidx_ylen =
52  detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);
53 
54  constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
55  constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
56  constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
57 
58  auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);
59 
60  sliced_tensor.get_thread_buffer() =
61  tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths);
62 
63  return sliced_tensor;
64 }
65 
66 template <typename DstDataType_,
67  typename DstStaticTileDistribution_,
68  typename SrcDataType_,
69  typename SrcStaticTileDistribution_,
70  index_t... SliceBegins,
71  index_t... SliceEnds>
72 CK_TILE_DEVICE constexpr auto
75  sequence<SliceBegins...> slice_begins,
76  sequence<SliceEnds...> slice_ends)
77 {
78  using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;
79 
80  constexpr auto sliced_dstr_yidx_ylen =
81  detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
82 
83  constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
84  constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
85  constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
86 
87  static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
88 
89  dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
90 }
91 
92 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto slice_distribution_from_x(Distribution, sequence< XSliceBegins... > x_slice_begins, sequence< XSliceEnds... > x_slice_ends)
Definition: tile_distribution.hpp:554
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition: slice_tile.hpp:23
constexpr CK_TILE_HOST_DEVICE auto sequence_to_tuple_of_number(sequence< Is... >)
Definition: container_helper.hpp:459
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto to_multi_index(const T &x)
Definition: multi_index.hpp:33
constexpr CK_TILE_DEVICE auto set_slice_tile(static_distributed_tensor< DstDataType_, DstStaticTileDistribution_ > &dst_tile, const static_distributed_tensor< SrcDataType_, SrcStaticTileDistribution_ > &src_tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition: slice_tile.hpp:73
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr bool is_same_v
Definition: type.hpp:283
Definition: sequence.hpp:49
Definition: static_distributed_tensor.hpp:21
CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence< YSliceOrigins... >, sequence< YSliceLengths... >) const
Definition: static_distributed_tensor.hpp:68
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
This class describes a tile-windowed view of device memory.
Definition: tile_window.hpp:873