/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/static_distributed_tensor.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/static_distributed_tensor.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/static_distributed_tensor.hpp Source File
static_distributed_tensor.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
17 namespace ck_tile {
18 
19 template <typename DataType_, typename StaticTileDistribution_>
21 {
24 
// Reject any distribution type whose is_static() is false: everything below
// (descriptor type, buffer size, element offsets) is evaluated in constant expressions.
25  static_assert(StaticTileDistribution::is_static(),
26  "wrong! StaticTileDistribution should be known at compile time");
27 
29  remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
30  static constexpr index_t PackedSize =
32 
// Element space size reported by a default-constructed ThreadTensorDesc.
// A non-positive value is rejected at compile time as an invalid tile distribution.
33  static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
34  static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
35 
// Returns the dimension count of this tensor, which is whatever the distribution
// reports from get_num_of_dimension_x().
36  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
37  {
38  return StaticTileDistribution::get_num_of_dimension_x();
39  }
40 
// Returns the lengths reported by the distribution; forwards directly to
// StaticTileDistribution::get_lengths() without modification.
41  CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
42  {
43  return StaticTileDistribution::get_lengths();
44  }
45 
47  {
48  return StaticTileDistribution{};
49  }
50 
52  {
53  return StaticTileDistribution::get_distributed_spans();
54  }
55 
// Forwards x to thread_buf_.initialize().
// NOTE(review): presumably sets every element of the thread buffer to x -- confirm
// against thread_buffer::initialize, which is not visible in this file.
56  CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); }
57 
// Read-only access: returns a const reference to the underlying thread_buf_ member.
58  CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; }
59 
// Mutable access: returns a non-const reference to the underlying thread_buf_ member,
// so callers can modify the stored elements in place.
60  CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; }
61 
63  {
65  }
66 
67  template <index_t... YSliceOrigins, index_t... YSliceLengths>
70  {
71  static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
72  sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
73  "wrong!");
74 
75  constexpr auto sliced_thread_tensor_desc =
77 
78  thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
79  sliced_thread_data;
80 
81  static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
82  constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
83 
84  sliced_thread_data(
85  number<sliced_thread_tensor_desc.calculate_offset(idx) / PackedSize>{}) =
86  thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}];
87  });
88 
89  return sliced_thread_data;
90  }
91 
92  template <index_t... YSliceOrigins, index_t... YSliceLengths, typename SlicedThreadData>
95  const SlicedThreadData& sliced_thread_data)
96  {
97  static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
98  sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
99  "wrong!");
100 
101  constexpr auto sliced_thread_tensor_desc =
103 
104  static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
105  constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
106 
107  thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}) =
108  sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx) /
109  PackedSize>{}];
110  });
111  }
112 
// Read-only element access by tile-distributed indices.
// The argument is used only for its type: TileDistributedIndices must be static
// (enforced below), and a default-constructed instance of it is what gets converted.
// Returns a const reference to the selected element of thread_buf_.
113  template <typename TileDistributedIndices>
114  CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const
115  {
116  static_assert(is_static_v<TileDistributedIndices>,
117  "wrong! Tile Distributed Indices should be static");
118 
     // Convert the distributed indices into Y indices via the tile distribution.
119  constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
120  TileDistributedIndices{});
121 
     // The buffer slot is the descriptor offset of y_idx divided by PackedSize.
     // NOTE(review): presumably several logical elements share one slot when
     // PackedSize > 1 -- confirm against PackedSize's definition.
122  return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{}];
123  }
124 
// Mutable element access by tile-distributed indices.
// The argument is used only for its type: TileDistributedIndices must be static
// (enforced below), and a default-constructed instance of it is what gets converted.
// Returns a non-const reference to the selected element of thread_buf_.
125  template <typename TileDistributedIndices>
126  CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices)
127  {
128  static_assert(is_static_v<TileDistributedIndices>,
129  "wrong! Tile Distributed Indices should be static");
130 
     // Convert the distributed indices into Y indices via the tile distribution.
131  constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
132  TileDistributedIndices{});
133 
     // The buffer slot is the descriptor offset of y_idx divided by PackedSize.
     // NOTE(review): presumably several logical elements share one slot when
     // PackedSize > 1 -- confirm against PackedSize's definition.
134  return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{});
135  }
136 
137  //
139 };
140 
141 template <typename DataType, typename StaticTileDistribution>
142 CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&)
143 {
146 }
147 
148 template <typename DataType, typename StaticTileDistribution, typename ThreadBuffer>
149 CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&,
150  ThreadBuffer&& thread_buffer_)
151 {
154 }
155 
156 // get X indices from tuple of tile_distributed_index<>
157 template <typename StaticTileDistribution, typename DistributedIndices>
158 CK_TILE_HOST_DEVICE constexpr auto
160  DistributedIndices distributed_indices)
161 {
162  const auto partition_index = detail::get_partition_index(tile_distribution);
163  constexpr auto y_indices =
165 
166  const auto x_coord = make_tensor_adaptor_coordinate(
168  container_concat(partition_index, to_array<ck_tile::index_t, y_indices.size()>(y_indices)));
169 
170  return x_coord.get_bottom_index();
171 }
172 
173 template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
176  DataType value,
177  XIndicesPredicate predicate)
178 {
179  constexpr auto out_spans =
181  sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
182  sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
183  constexpr auto distributed_indices = make_tuple(idx0, idx1);
184  const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{},
185  distributed_indices);
186 
187  if(predicate(x_indices))
188  {
189  out_tensor(distributed_indices) = value;
190  }
191  });
192  });
193 }
194 
195 // this function is used inside a loop over spans
196 template <typename YLengths, index_t XUnpacks>
198 {
199  constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies{}, number<1>{});
200  constexpr auto y_packs = number<XUnpacks>{};
201  static_assert(y_size % y_packs == 0);
202  constexpr auto y_slice_size = y_size / y_packs;
203 
204  constexpr auto slice_info = slice_sequence(YLengths{}, number<y_slice_size>{});
205  constexpr auto unpacks = slice_info[number<1>{}];
206  return unpacks;
207 }
208 
209 namespace detail {
210 
211 // check whether two static_distributed_tensor types have the same data type and
212 // the same thread buffer size, differing only in their distribution
213 template <typename X, typename Y>
215 {
216  static constexpr bool value = false;
217 };
218 
219 template <typename TypeX, typename DistX, typename TypeY, typename DistY>
221  static_distributed_tensor<TypeY, DistY>>
222 {
225  static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
226  Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
227 };
228 
229 template <typename X, typename Y>
230 inline constexpr bool is_similiar_distributed_tensor_v =
232 
233 } // namespace detail
234 
235 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr bool is_similiar_distributed_tensor_v
Definition: static_distributed_tensor.hpp:230
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto to_array(const std::vector< X > &x)
Definition: array.hpp:286
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:979
constexpr CK_TILE_HOST_DEVICE auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition: static_distributed_tensor.hpp:142
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition: tensor_descriptor.hpp:365
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition: static_distributed_tensor.hpp:175
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto get_y_unpacks_from_x_unpacks(YLengths, number< XUnpacks >)
Definition: static_distributed_tensor.hpp:197
constexpr CK_TILE_HOST_DEVICE auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition: static_distributed_tensor.hpp:159
constexpr auto slice_sequence(Seq, number< SliceSize >, Mask=typename uniform_sequence_gen< Seq::size(), 1 >::type{})
Definition: sequence.hpp:1246
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: integral_constant.hpp:13
Definition: static_distributed_tensor.hpp:215
static constexpr bool value
Definition: static_distributed_tensor.hpp:216
Definition: math.hpp:98
Definition: numeric.hpp:81
Definition: sequence.hpp:49
Definition: static_distributed_tensor.hpp:21
static constexpr index_t kThreadElementSpaceSize
Definition: static_distributed_tensor.hpp:33
constexpr CK_TILE_HOST_DEVICE DataType & operator()(TileDistributedIndices)
Definition: static_distributed_tensor.hpp:126
remove_cvref_t< StaticTileDistribution_ > StaticTileDistribution
Definition: static_distributed_tensor.hpp:23
CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence< YSliceOrigins... >, sequence< YSliceLengths... >, const SlicedThreadData &sliced_thread_data)
Definition: static_distributed_tensor.hpp:93
CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence< YSliceOrigins... >, sequence< YSliceLengths... >) const
Definition: static_distributed_tensor.hpp:68
thread_buffer< DataType, get_thread_buffer_size()> thread_buf_
Definition: static_distributed_tensor.hpp:138
static constexpr CK_TILE_HOST_DEVICE auto get_lengths()
Definition: static_distributed_tensor.hpp:41
remove_cvref_t< DataType_ > DataType
Definition: static_distributed_tensor.hpp:22
static constexpr CK_TILE_HOST_DEVICE auto get_tile_distribution()
Definition: static_distributed_tensor.hpp:46
static constexpr CK_TILE_HOST_DEVICE auto get_distributed_spans()
Definition: static_distributed_tensor.hpp:51
static constexpr index_t PackedSize
Definition: static_distributed_tensor.hpp:30
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_dimension()
Definition: static_distributed_tensor.hpp:36
constexpr CK_TILE_HOST_DEVICE auto & get_thread_buffer()
Definition: static_distributed_tensor.hpp:60
static constexpr CK_TILE_HOST_DEVICE index_t get_thread_buffer_size()
Definition: static_distributed_tensor.hpp:62
remove_cvref_t< decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())> ThreadTensorDesc
Definition: static_distributed_tensor.hpp:29
constexpr CK_TILE_HOST_DEVICE const DataType & operator[](TileDistributedIndices) const
Definition: static_distributed_tensor.hpp:114
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
CK_TILE_HOST_DEVICE void initialize(const DataType &x)
Definition: static_distributed_tensor.hpp:56
Definition: functional.hpp:141
Definition: debug.hpp:67
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
static constexpr CK_TILE_HOST_DEVICE auto get_y_indices_from_distributed_indices(DistributedIndices)
Definition: tile_distribution.hpp:205