include/ck_tile/core/tensor/static_distributed_tensor.hpp Source File

include/ck_tile/core/tensor/static_distributed_tensor.hpp Source File

Composable Kernel: include/ck_tile/core/tensor/static_distributed_tensor.hpp Source File
static_distributed_tensor.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
17 namespace ck_tile {
18 
19 template <typename DataType_, typename StaticTileDistribution_>
21 {
24 
25  static_assert(StaticTileDistribution::is_static(),
26  "wrong! StaticTileDistribution should be known at compile tile");
27 
29  remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
30 
31  static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
32  static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
33 
34  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
35  {
36  return StaticTileDistribution::get_num_of_dimension_x();
37  }
38 
39  CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
40  {
41  return StaticTileDistribution::get_lengths();
42  }
43 
45  {
46  return StaticTileDistribution{};
47  }
48 
50  {
51  return StaticTileDistribution::get_distributed_spans();
52  }
53 
54  CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); }
55 
56  CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; }
57 
58  CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; }
59 
61  {
63  }
64 
65  template <index_t... YSliceOrigins, index_t... YSliceLengths>
68  {
69  static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
70  sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
71  "wrong!");
72 
73  constexpr auto sliced_thread_tensor_desc =
75 
76  thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
77  sliced_thread_data;
78 
79  static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
80  constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
81 
82  sliced_thread_data(number<sliced_thread_tensor_desc.calculate_offset(idx)>{}) =
83  thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}];
84  });
85 
86  return sliced_thread_data;
87  }
88 
89  template <index_t... YSliceOrigins, index_t... YSliceLengths, typename SlicedThreadData>
92  const SlicedThreadData& sliced_thread_data)
93  {
94  static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
95  sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
96  "wrong!");
97 
98  constexpr auto sliced_thread_tensor_desc =
100 
101  static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
102  constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
103 
104  thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}) =
105  sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx)>{}];
106  });
107  }
108 
109  template <typename TileDistributedIndices>
110  CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const
111  {
112  static_assert(is_static_v<TileDistributedIndices>,
113  "wrong! Tile Distributed Indices should be static");
114 
115  constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
116  TileDistributedIndices{});
117 
118  return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx)>{}];
119  }
120 
121  template <typename TileDistributedIndices>
122  CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices)
123  {
124  static_assert(is_static_v<TileDistributedIndices>,
125  "wrong! Tile Distributed Indices should be static");
126 
127  constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
128  TileDistributedIndices{});
129 
130  return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx)>{});
131  }
132 
133  //
135 };
136 
137 template <typename DataType, typename StaticTileDistribution>
138 CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&)
139 {
142 }
143 
144 template <typename DataType, typename StaticTileDistribution, typename ThreadBuffer>
145 CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&,
146  ThreadBuffer&& thread_buffer_)
147 {
150 }
151 
152 // get X indices from tuple of tile_distributed_index<>
153 template <typename StaticTileDistribution, typename DistributedIndices>
154 CK_TILE_HOST_DEVICE constexpr auto
156  DistributedIndices distributed_indices)
157 {
158  const auto partition_index = detail::get_partition_index(tile_distribution);
159  constexpr auto y_indices =
161 
162  const auto x_coord = make_tensor_adaptor_coordinate(
164  container_concat(partition_index, to_array<ck_tile::index_t, y_indices.size()>(y_indices)));
165 
166  return x_coord.get_bottom_index();
167 }
168 
169 template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
172  DataType value,
173  XIndicesPredicate predicate)
174 {
175  constexpr auto out_spans =
177  sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
178  sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
179  constexpr auto distributed_indices = make_tuple(idx0, idx1);
180  const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{},
181  distributed_indices);
182 
183  if(predicate(x_indices))
184  {
185  out_tensor(distributed_indices) = value;
186  }
187  });
188  });
189 }
190 
191 // this function is used inside a loop over distributed spans
192 template <typename YLengths, index_t XUnpacks>
194 {
195  constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies{}, number<1>{});
196  constexpr auto y_packs = number<XUnpacks>{};
197  static_assert(y_size % y_packs == 0);
198  constexpr auto y_slice_size = y_size / y_packs;
199 
200  constexpr auto slice_info = slice_sequence(YLengths{}, number<y_slice_size>{});
201  constexpr auto unpacks = slice_info[number<1>{}];
202  return unpacks;
203 }
204 
205 namespace detail {
206 
207 // check whether two static_distributed_tensor types have the same data type and the same
208 // thread buffer size, i.e. they may differ only in their tile distribution
209 template <typename X, typename Y>
211 {
212  static constexpr bool value = false;
213 };
214 
215 template <typename TypeX, typename DistX, typename TypeY, typename DistY>
217  static_distributed_tensor<TypeY, DistY>>
218 {
221  static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
222  Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
223 };
224 
225 template <typename X, typename Y>
226 inline constexpr bool is_similiar_distributed_tensor_v =
228 
229 } // namespace detail
230 
231 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr bool is_similiar_distributed_tensor_v
Definition: static_distributed_tensor.hpp:226
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto to_array(const std::vector< X > &x)
Definition: array.hpp:241
tuple_array< T, N > thread_buffer
Definition: thread_buffer.hpp:14
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:973
constexpr CK_TILE_HOST_DEVICE auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition: static_distributed_tensor.hpp:138
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition: tensor_descriptor.hpp:352
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition: static_distributed_tensor.hpp:171
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE auto get_y_unpacks_from_x_unpacks(YLengths, number< XUnpacks >)
Definition: static_distributed_tensor.hpp:193
constexpr CK_TILE_HOST_DEVICE auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition: static_distributed_tensor.hpp:155
constexpr auto slice_sequence(Seq, number< SliceSize >, Mask=typename uniform_sequence_gen< Seq::size(), 1 >::type{})
Definition: sequence.hpp:1225
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:86
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
Definition: integral_constant.hpp:13
Definition: static_distributed_tensor.hpp:211
static constexpr bool value
Definition: static_distributed_tensor.hpp:212
Definition: math.hpp:98
Definition: sequence.hpp:52
static_distributed_tensor< DataType_, StaticTileDistribution_ >
Definition: static_distributed_tensor.hpp:21
static constexpr index_t kThreadElementSpaceSize
Definition: static_distributed_tensor.hpp:31
constexpr CK_TILE_HOST_DEVICE DataType & operator()(TileDistributedIndices)
Definition: static_distributed_tensor.hpp:122
remove_cvref_t< StaticTileDistribution_ > StaticTileDistribution
Definition: static_distributed_tensor.hpp:23
CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence< YSliceOrigins... >, sequence< YSliceLengths... >, const SlicedThreadData &sliced_thread_data)
Definition: static_distributed_tensor.hpp:90
CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence< YSliceOrigins... >, sequence< YSliceLengths... >) const
Definition: static_distributed_tensor.hpp:66
static constexpr CK_TILE_HOST_DEVICE auto get_lengths()
Definition: static_distributed_tensor.hpp:39
remove_cvref_t< DataType_ > DataType
Definition: static_distributed_tensor.hpp:22
static constexpr CK_TILE_HOST_DEVICE auto get_tile_distribution()
Definition: static_distributed_tensor.hpp:44
static constexpr CK_TILE_HOST_DEVICE auto get_distributed_spans()
Definition: static_distributed_tensor.hpp:49
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_dimension()
Definition: static_distributed_tensor.hpp:34
thread_buffer< DataType, kThreadElementSpaceSize > thread_buf_
Definition: static_distributed_tensor.hpp:134
constexpr CK_TILE_HOST_DEVICE auto & get_thread_buffer()
Definition: static_distributed_tensor.hpp:58
static constexpr CK_TILE_HOST_DEVICE index_t get_thread_buffer_size()
Definition: static_distributed_tensor.hpp:60
remove_cvref_t< decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())> ThreadTensorDesc
Definition: static_distributed_tensor.hpp:29
constexpr CK_TILE_HOST_DEVICE const DataType & operator[](TileDistributedIndices) const
Definition: static_distributed_tensor.hpp:110
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:56
CK_TILE_HOST_DEVICE void initialize(const DataType &x)
Definition: static_distributed_tensor.hpp:54
Definition: functional.hpp:117
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
static constexpr CK_TILE_HOST_DEVICE auto get_y_indices_from_distributed_indices(DistributedIndices)
Definition: tile_distribution.hpp:205