/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_distribution.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_distribution.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_distribution.hpp Source File
tile_distribution.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck_tile {
19 
20 namespace detail {
21 template <typename Distribution>
23 {
24  return Distribution::_get_partition_index();
25 }
26 } // namespace detail
27 
28 // distributed span
29 template <index_t... PartialHsLengths>
31 {
32  using Impl = sequence<PartialHsLengths...>;
33 
34  static constexpr auto impl_ = Impl{};
35 
36  CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
37 };
38 
39 // distributed index
40 template <index_t... PartialHsIndices>
42 {
43  using Impl = sequence<PartialHsIndices...>;
44 
45  static constexpr auto impl_ = Impl{};
46 
47  CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
48 };
49 
50 namespace detail {
51 
52 template <index_t... Is>
54 {
55  return tile_distributed_span<Is...>{};
56 }
57 
58 template <index_t... Is>
60 {
61  return tile_distributed_index<Is...>{};
62 }
63 
64 } // namespace detail
65 
66 template <typename PsYs2XsAdaptor_,
67  typename Ys2DDescriptor_,
68  typename StaticTileDistributionEncoding_,
69  typename TileDistributionDetail_> // FIXME: this is for hold ad-hoc but useful info,
70  // should be more elegant
72 {
77 
79  "wrong! should be static");
80 
81  static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
82  static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
83  static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY;
84  static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR;
85 
88 
93 
95  {
96  // only support warp-tile and block-tile
97  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
98 
99  if constexpr(NDimP == 1)
100  {
101  return array<index_t, 1>{get_lane_id()};
102  }
103  else if constexpr(NDimP == 2)
104  {
105  return array<index_t, 2>{get_warp_id(), get_lane_id()};
106  }
107  }
108 
109  CK_TILE_HOST_DEVICE static constexpr auto get_lengths() // X lengths as a tuple of number<>, one entry per X dim
110  {
111 #if 0
112  // FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed. NOTE(review): this disabled branch also has no `return` and reads the non-static member ps_ys_to_xs_ from a static function, so it needs rework before re-enabling
113  ps_ys_to_xs_.GetBottomDimensionLengths();
114 #else
115  return generate_tuple(
116  [&](auto i) {
117  constexpr index_t x_length = // length of X dim i = product of that dim's H lengths
118  container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1);
119 
120  return number<x_length>{};
121  },
122  number<NDimX>{});
123 #endif
124  }
125 
126  CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const // const reference to the (P, Y) -> X adaptor member ps_ys_to_xs_
127  {
128  return ps_ys_to_xs_;
129  }
130 
131  CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; } // const reference to the Y -> D descriptor member ys_to_d_
132 
134  {
135  return DstrEncode{};
136  }
137 
138 #if 1
139  // Calculate the replication index [R0, R1, ...] from a partition index [P0, P1, ...]
140  // FIXME: very nasty implementation: builds a full adaptor coordinate with every Y index set to 0, then reads the R entries out of its hidden index
141  template <typename PartitionIndex>
142  CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const
143  {
144  static_assert(PartitionIndex::size() == NDimP, "wrong!");
145 
146  const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0}); // top index [ps..., 0, ..., 0]
147 
148  const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
149 
150  array<index_t, NDimR> rs_idx; // NOTE(review): not explicitly initialized; R dims not reached by any P below keep whatever array<> default-construction leaves - confirm
151 
152  static_for<0, NDimP, 1>{}([&](auto idim_p) {
153  constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();
154 
155  static_for<0, ndim_low, 1>{}([&](auto i) {
156  constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
157  constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];
158 
159  // 0-th rh_major is the replicate dimension
160  if constexpr(rh_major == 0)
161  {
162  constexpr index_t adaptor_hidden_id =
163  DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];
164 
165  // fill in R index rh_minor from the coordinate's hidden index
166  rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
167  }
168  });
169  });
170 
171  return rs_idx;
172  }
173 #endif
174 
175  template <typename PartitionIndex = decltype(_get_partition_index())>
177  calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
178  {
179  const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
180  const auto window_adaptor_thread_coord_tmp =
182  return window_adaptor_thread_coord_tmp.get_bottom_index();
183  }
184 
186  {
187  constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
188  constexpr auto ndims_spans_minor = DstrEncode::detail::ndims_distributed_spans_minor_;
189 
190  return generate_tuple(
191  [&](auto i) {
192  constexpr auto span_impl = distributed_spans_impl[i];
193  constexpr index_t ndim_span_minor = ndims_spans_minor[i];
194 
195  constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor);
196 
198  },
199  number<NDimX>{});
200  }
201 
202  // FIXME: it's hacky to get Y index from Distributed-Index
203  template <typename DistributedIndices>
204  CK_TILE_HOST_DEVICE static constexpr auto
206  {
207  constexpr auto ys_idx_arr = [] {
208  array<index_t, NDimY> ys_idx;
209 
210  static_for<0, NDimY, 1>{}([&](auto i) {
211  constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
212  constexpr index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];
213 
214  constexpr auto dstr_index = DistributedIndices{}[number<span_major>{}];
215 
216  ys_idx(i) = dstr_index.impl_[span_minor];
217  });
218 
219  return ys_idx;
220  }();
221 
222  constexpr index_t ndim_y = NDimY;
223 
224  return TO_SEQUENCE(ys_idx_arr, ndim_y);
225  }
226 
227  CK_TILE_HOST_DEVICE static constexpr bool is_static()
228  {
230  }
231 };
232 
233 namespace detail {
234 
235 template <index_t NDimMax>
237 {
239 
240  for(index_t i = 0; i < iend - ibegin; ++i)
241  {
242  arr(i) = ibegin + i;
243  }
244 
245  return arr;
246 }
247 
248 // this returns a constexpr encoding of tile_distribution
249 template <typename StaticTileDistributionEncoding_>
250 CK_TILE_HOST_DEVICE constexpr auto
251 make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
252 {
253  using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
254  using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
255  using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor;
256  using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor;
257  using Ys2RHsMajor = typename StaticTileDistributionEncoding_::Ys2RHsMajor;
258  using Ys2RHsMinor = typename StaticTileDistributionEncoding_::Ys2RHsMinor;
259 
260  // FIXME: increase max value if fail
261  constexpr index_t kMaxNumTransforms = 20;
262  constexpr index_t kMaxMetaDataSize = 128;
263  constexpr index_t kMaxNumDim = 10;
264 
265  using Name = coord_transform_enum;
266  using MetaData = meta_data_buffer<kMaxMetaDataSize>;
267  using NumDim = index_t;
268  using Dims = array<index_t, kMaxNumDim>;
269  using Lengths = array<index_t, kMaxNumDim>;
270 
271  // Tile Adaptor
272  // bottom dims [x0, x1, x2, ...]
273  // top dims [p0, p1, ..., y0, y1, ...]
274  constexpr index_t ndim_x = HsLengthss::size();
275 
276  // Dim Ids: [idim_x_major, idim_x_minor] to [idim_hidden]
277  array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_ids;
278  array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_lengths;
279 
280  auto trans = array<tuple<Name, MetaData, NumDim, Dims, NumDim, Dims>, kMaxNumTransforms>{};
281 
282  index_t num_tran = 0;
283  index_t hidden_dim_cnt = ndim_x;
284 
285  // this is replicate transform
286  {
287  constexpr index_t ndim_r_minor = RsLengths::size();
288 
289  constexpr auto r_minor_lengths = RsLengths{};
290 
291  trans(num_tran++) = {
293  MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
294  NumDim{0},
295  Dims{},
296  NumDim{ndim_r_minor},
297  make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};
298 
299  for(index_t i = 0; i < ndim_r_minor; ++i)
300  {
301  rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt;
302  rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];
303 
304  hidden_dim_cnt++;
305  }
306  };
307 
308  // these are Unmerge transforms for X dimesions
309  static_for<0, ndim_x, 1>{}([&trans,
310  &num_tran,
311  &hidden_dim_cnt,
312  &rh_major_minor_to_hidden_ids,
313  &rh_major_minor_to_hidden_lengths](auto idim_x) {
314  // typename HsLengthss::base{}.foo();
315  constexpr auto h_minor_lengths =
316  HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
317  // constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
318 
319  constexpr index_t ndim_h_minor = h_minor_lengths.size();
320 
321  trans(num_tran++) = {
323  MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
324  NumDim{1},
325  Dims{idim_x},
326  NumDim{ndim_h_minor},
327  make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};
328 
329  for(index_t i = 0; i < ndim_h_minor; ++i)
330  {
331  rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt;
332  rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];
333 
334  hidden_dim_cnt++;
335  }
336  });
337 
338  // transform: P dimensions
339  constexpr index_t ndim_p = Ps2RHssMajor::size();
340 
341  Dims hidden_dim_id_ps;
342 
343  static_for<0, ndim_p, 1>{}([&](auto iDimP) {
344  //
345  index_t hidden_dim_id_p = hidden_dim_cnt++;
346 
347  hidden_dim_id_ps(iDimP) = hidden_dim_id_p;
348 
349  constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
350  constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP];
351 
352  static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!");
353 
354  constexpr index_t ndim_low = p2RHsMajor.size();
355 
356  Dims low_dims;
357  Lengths low_lengths;
358 
359  for(index_t i = 0; i < ndim_low; ++i)
360  {
361  index_t rh_major = p2RHsMajor[i];
362  index_t rh_minor = p2RHsMinor[i];
363  low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
364  low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
365  }
366 
367  trans(num_tran++) = {coord_transform_enum::merge,
368  MetaData{to_array<index_t, ndim_low>(low_lengths)},
369  NumDim{ndim_low},
370  low_dims,
371  NumDim{1},
372  Dims{hidden_dim_id_p}};
373  });
374 
375  constexpr index_t ndim_bottom = ndim_x;
376 
377  constexpr auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);
378 
379  constexpr auto ys_to_rhs_major = Ys2RHsMajor{};
380  constexpr auto ys_to_rhs_minor = Ys2RHsMinor{};
381 
382  constexpr index_t ndim_y = Ys2RHsMajor::size();
383  constexpr index_t ndim_top = ndim_p + ndim_y;
384 
385  auto top_dim_ids = hidden_dim_id_ps;
386 
387  {
388  for(index_t i = 0; i < ndim_y; ++i)
389  {
390  index_t rh_major = ys_to_rhs_major[i];
391  index_t rh_minor = ys_to_rhs_minor[i];
392  top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
393  }
394  }
395 
396  //
397  const auto ps_ys_to_xs_adaptor_encoding =
398  make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);
399 
400  // descriptor: [y0, y1, ...] to [d]
401  Lengths y_lengths;
402  index_t d_length = 1;
403 
404  for(index_t i = 0; i < ndim_y; ++i)
405  {
406  index_t rh_major = ys_to_rhs_major[i];
407  index_t rh_minor = ys_to_rhs_minor[i];
408  index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
409  y_lengths(i) = y_length;
410  d_length *= y_length;
411  }
412 
414  MetaData{to_array<index_t, ndim_y>(y_lengths)},
415  NumDim{1},
416  Dims{0},
417  NumDim{ndim_y},
418  make_sequential_index<kMaxNumDim>(1, ndim_y + 1));
419 
420  const auto ys_to_d_adaptor_encoding = make_tuple(
421  make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);
422 
423  return make_tuple(ps_ys_to_xs_adaptor_encoding,
424  ys_to_d_adaptor_encoding,
425  d_length,
426  rh_major_minor_to_hidden_ids);
427 }
428 
429 // FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
430 template <typename RhMajorMinor2AdaptorHiddenIdss> // tuple<sequence<...>, ...>
432 {
434  to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{});
435 };
436 
437 } // namespace detail
438 
439 #if 0
440 // this returns a constexpr tile_distribution (compiled out by #if 0; unlike make_static_tile_distribution it uses CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING rather than the STATIC variant)
441 template <typename StaticTileDistributionEncoding_>
442 CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
443 {
444  using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
445 
446  constexpr auto adaptor_impl = // (ps_ys_to_xs encoding, ys_to_d encoding, d_length, rh_major_minor_to_hidden_ids)
447  detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
448 
449  constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
450  constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
451  constexpr index_t d_length = adaptor_impl.template at<2>();
452  constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
453 
454  constexpr auto ps_ys_to_xs_adaptor =
455  CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
456 
457  constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
458 
459  constexpr auto ys_to_d_descriptor =
460  make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);
461 
462  // convert rh_major_minor_to_hidden_ids from an array of arrays into a tuple of sequences, sized by the encoding's rh major/minor dim counts
463  constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
464  constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
465 
466  constexpr auto rh_major_minor_to_hidden_ids =
467  TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
468 
469  return tile_distribution<
470  remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
471  remove_cvref_t<decltype(ys_to_d_descriptor)>,
472  remove_cvref_t<DstrEncode>,
473  detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
474  ps_ys_to_xs_adaptor, ys_to_d_descriptor};
475 }
476 #endif
477 
478 // this returns a static tile_distribution
479 template <typename StaticTileDistributionEncoding_>
480 CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
481 {
483 
484  constexpr auto adaptor_impl =
485  detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
486 
487  constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
488  constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
489  constexpr index_t d_length = adaptor_impl.template at<2>();
490  constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
491 
492  constexpr auto ps_ys_to_xs_adaptor =
493  CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
494 
495  constexpr auto ys_to_d_adaptor =
496  CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
497 
498  constexpr auto ys_to_d_descriptor =
500 
501  //
502  constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
503  constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
504 
505  constexpr auto rh_major_minor_to_hidden_ids =
506  TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
507 
508  return tile_distribution<
509  remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
510  remove_cvref_t<decltype(ys_to_d_descriptor)>,
512  detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
513  ps_ys_to_xs_adaptor, ys_to_d_descriptor};
514 }
515 
516 //***********************************************************************************
517 
518 namespace detail {
519 //
520 // Slice a tensor along x_dim; the result is split in y_dim, not p_dim.
521 // We don't support slicing across p_dim (i.e., slicing across different threads).
522 // Also, the y_dim being sliced needs to be the first dim of the current x dim.
523 // Multiple Y dims before the sliced dim do not make sense.
524 //
525 // e.g
526 // X0 X1
527 // <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 32>, (-1 means the last one)
528 // Y P P Y P Y P Y
529 // => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
530 // |--> slice along this Y dim, which is the first dim of X1; 4 slices in total
531 //
532 // X0 X1
533 // <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 8>, (-1 means the last one)
534 // Y P P Y P Y P Y
535 // => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
536 // |--> slice along this Y dim; the P dim to its left is 1, so this is OK
537 // 16 slices in total
538 //
539 // X0 X1
540 // <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 4>, (-1 means the last one)
541 // Y P P Y P Y P Y
542 // => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
543 // |--> slice along this P dim, will split threads, not supported
544 //
545 // X0 X1
546 // <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 16>, (-1 means the last one)
547 // Y P P Y P Y P Y
548 // => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
549 // |--> slice along this Y dim, but this Y dim needs to be split into 2
550 // sub-dims
551 // the P dim to its left is 1, so the slice does not actually cross P
552 //
553 template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
555  Distribution, sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)
556 {
557  // NOTE: this function need to be called under constexpr context,
558  // due to https://wg21.link/p2280r0 we have to use non-reference type for distribution
559  using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
560 
561  static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
562  static_assert(sizeof...(XSliceBegins) == Encoding::NDimX, "only support slice over h, not r");
563 
564  constexpr auto p_len_over_h = Encoding::detail::get_uniformed_p_dim_lengths_over_h();
565 
566  constexpr auto x_slice_ends_ = generate_sequence_v2(
567  [&](auto i) {
568  if constexpr(x_slice_ends[i] == -1)
569  {
570  // -1 means till the end
571  constexpr auto x_length_ =
572  container_reduce(typename Encoding::HsLengthss{}[i], multiplies{}, number<1>{});
573  return x_length_;
574  }
575  else
576  {
577  return x_slice_ends[i];
578  }
579  },
580  number<x_slice_ends.size()>{});
581 
582  constexpr auto x_slice_lengths = x_slice_ends_ - x_slice_begins;
583 
584  constexpr auto x_slice_lengths_without_p = generate_sequence_v2(
585  [&](auto i) constexpr {
586  constexpr auto len_ = x_slice_lengths[i];
587  static_assert(len_ % p_len_over_h[i] == 0,
588  "slice length must be dividable by p_len_over_h");
589  return number<len_ / p_len_over_h[i]>{};
590  },
591  number<x_slice_lengths.size()>{});
592 
593  constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
594  constexpr auto src_y_info = Encoding::detail::get_sorted_y_to_h_info();
595  constexpr auto src_y_dims = src_y_info[number<0>{}];
596  constexpr auto src_y_maps = src_y_info[number<1>{}];
597  constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
598 
599  constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr {
600  auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
601  auto y_slice_lengths = Encoding::detail::ys_lengths_;
602  constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks();
603 
604  // This lambda will modify some value outside, so c++ will not treat return value as
605  // constexpr
606  // TODO: ugly
607  auto new_h_lengths = transform_tuples(
608  [&](auto h_len, auto id) {
609  constexpr auto sliced_h = reverse_slice_sequence(
610  h_len, number<x_slice_lengths_without_p[id]>{}, y_to_h_masks[id]);
611 
612  constexpr auto sliced_h_lens = sliced_h[number<0>{}];
613  constexpr auto sliced_h_index = sliced_h[number<2>{}];
614 
615  // update y_slice_lengths
616  constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
617  constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
618  constexpr auto y_to_h_dim_end = src_y_prefix_sum[id + 1];
619 
620  static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
621  "not sliced at y dim, please check");
622 
623  {
624  constexpr auto sliced_y_to_h_lens =
625  pick_sequence_elements_by_mask(sliced_h_lens, y_to_h_masks[id]);
626  constexpr auto sliced_y_to_h_dims = sliced_y_to_h_lens.size();
628  y_slice_lengths(src_y_maps[y_to_h_dim_end - 1 - i]) =
629  sliced_y_to_h_lens[sliced_y_to_h_dims - 1 - i];
630  });
631  }
632  // TODO: add validations not across p dim
633 
634  // NOTE: this y_origin is for all dims, not only current dim
635  // will later use pick to select target dim
636  constexpr auto y_origin = [&]() {
637  // can't use Encoding::Ys2RHsMajor/Ys2RHsMinor, these are unordered
638  constexpr auto y_to_h_len =
639  pick_sequence_elements_by_mask(h_len, y_to_h_masks[id]);
640  constexpr auto y_to_h_dims = y_to_h_len.size();
641 
642  constexpr auto h_trans = make_merge_transform_v3_division_mod(y_to_h_len);
643  auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
644  constexpr auto y_begin_ = x_slice_begins[id] / p_len_over_h[id];
645  h_trans.calculate_lower_index(h_origin_, sequence<y_begin_.value>{});
646 
647  auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
648 
649  static_for<0, y_to_h_dims, 1>{}([&](auto i) {
650  y_origin_(y_to_h_dim_end - 1 - i) = h_origin_[y_to_h_dims - 1 - i];
651  });
652  return y_origin_;
653  }();
654 
655  constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
656  src_y_prefix_sum[id + 1],
657  1>::type{};
658 
660  y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
661  return sliced_h_lens;
662  },
663  typename Encoding::HsLengthss{},
664  typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{});
665 
666  auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
667 
668  return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
669  }();
670 
671  constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
672  constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
673  constexpr auto sliced_y_origins_size = sliced_y_origins_array.size();
674  constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}];
675  constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.size();
676 
677  constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
678  constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
679 
680  return make_tuple(
682  tile_distribution_encoding<typename Encoding::RsLengths,
683  remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
684  // change the
685  // h_lengths type
686  typename Encoding::Ps2RHssMajor,
687  typename Encoding::Ps2RHssMinor,
688  typename Encoding::Ys2RHsMajor,
689  typename Encoding::Ys2RHsMinor>{}),
690  sliced_y_origins,
691  sliced_y_lengths);
692 }
693 
694 } // namespace detail
695 
696 // Free print function for tile_distribution: prints the static encoding, then the ps_ys_to_xs_ adaptor and the ys_to_d_ descriptor, wrapped in "tile_distribution{...}" and ending with a newline
697 template <typename PsYs2XsAdaptor_,
698  typename Ys2DDescriptor_,
699  typename StaticTileDistributionEncoding_,
700  typename TileDistributionDetail_>
701 CK_TILE_HOST_DEVICE void print(const tile_distribution<PsYs2XsAdaptor_,
702  Ys2DDescriptor_,
703  StaticTileDistributionEncoding_,
704  TileDistributionDetail_>& distribution)
705 {
706  printf("tile_distribution{");
707  printf("tile_distribution_encoding: ");
708  print(StaticTileDistributionEncoding_{}); // the encoding is a type, so a default-constructed instance is printed
709  printf(", ");
710  printf("ps_ys_to_xs_: ");
711  print(distribution.ps_ys_to_xs_);
712  printf(", ");
713  printf("ys_to_d_: ");
714  print(distribution.ys_to_d_);
715  printf("}\n");
716 }
717 
718 } // namespace ck_tile
Definition: span.hpp:18
Concept for encoding of Unicode characters.
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr CK_TILE_HOST_DEVICE auto make_sequential_index(index_t ibegin, index_t iend)
Definition: tile_distribution.hpp:236
constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_span(sequence< Is... >)
Definition: tile_distribution.hpp:53
constexpr CK_TILE_HOST_DEVICE auto slice_distribution_from_x(Distribution, sequence< XSliceBegins... > x_slice_begins, sequence< XSliceEnds... > x_slice_ends)
Definition: tile_distribution.hpp:554
constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_index(sequence< Is... >)
Definition: tile_distribution.hpp:59
constexpr CK_TILE_HOST_DEVICE auto make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:251
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_zero_multi_index()
Definition: multi_index.hpp:26
constexpr CK_TILE_HOST_DEVICE auto container_reorder_given_old2new(const array< TData, NSize > &old_array, sequence< IRs... > old2new)
Definition: container_helper.hpp:48
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
coord_transform_enum
Definition: coordinate_transform.hpp:17
constexpr CK_TILE_HOST_DEVICE auto pick_sequence_elements_by_mask(Seq, Mask)
Definition: sequence.hpp:942
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constexpr auto reverse_slice_sequence(Seq, number< SliceSize >, Mask=typename uniform_sequence_gen< Seq::size(), 1 >::type{})
Definition: sequence.hpp:1220
constexpr CK_TILE_HOST_DEVICE auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition: tuple.hpp:630
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d< BlockSize, YPerTile, XPerTile, VecSize, DistributionPattern, NumWaveGroups > &)
Definition: static_encoding_pattern.hpp:342
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_sequence_v2(F &&f, number< N >)
Definition: sequence.hpp:1042
constexpr CK_TILE_HOST_DEVICE auto make_tensor_descriptor_from_adaptor(const Adaptor &adaptor, const ElementSpaceSize &element_space_size)
Definition: tensor_descriptor.hpp:177
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1609
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr index_t container_find(sequence< Is... > seq, index_t value)
Definition: container_helper.hpp:447
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
constexpr CK_TILE_HOST_DEVICE auto transform_tuples(F f, const X &x)
Definition: tuple.hpp:505
Definition: sequence.hpp:284
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: tile_distribution.hpp:432
static constexpr auto rh_major_minor_to_adaptor_hidden_idss_
Definition: tile_distribution.hpp:433
Definition: meta_data_buffer.hpp:16
Definition: math.hpp:98
Definition: sequence.hpp:49
static constexpr CK_TILE_HOST_DEVICE index_t size()
Definition: sequence.hpp:53
Definition: functional.hpp:43
Definition: tile_distribution.hpp:42
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:47
static constexpr auto impl_
Definition: tile_distribution.hpp:45
Definition: tile_distribution.hpp:31
static constexpr auto impl_
Definition: tile_distribution.hpp:34
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:36
Definition: tile_distribution_encoding.hpp:26
Definition: tile_distribution.hpp:72
remove_cvref_t< Ys2DDescriptor_ > Ys2DDescriptor
Definition: tile_distribution.hpp:74
PsYs2XsAdaptor ps_ys_to_xs_
Definition: tile_distribution.hpp:86
static constexpr CK_TILE_HOST_DEVICE auto get_distributed_spans()
Definition: tile_distribution.hpp:185
static CK_TILE_HOST_DEVICE auto _get_partition_index()
Definition: tile_distribution.hpp:94
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
static constexpr index_t NDimY
Definition: tile_distribution.hpp:82
remove_cvref_t< StaticTileDistributionEncoding_ > DstrEncode
Definition: tile_distribution.hpp:75
remove_cvref_t< TileDistributionDetail_ > DstrDetail
Definition: tile_distribution.hpp:76
CK_TILE_HOST_DEVICE auto calculate_index(const PartitionIndex &ps_idx=_get_partition_index()) const
Definition: tile_distribution.hpp:177
static constexpr CK_TILE_HOST_DEVICE auto get_lengths()
Definition: tile_distribution.hpp:109
static constexpr index_t NDimP
Definition: tile_distribution.hpp:83
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_x()
Definition: tile_distribution.hpp:89
static constexpr CK_TILE_HOST_DEVICE auto get_y_indices_from_distributed_indices(DistributedIndices)
Definition: tile_distribution.hpp:205
CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex &ps_idx) const
Definition: tile_distribution.hpp:142
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_p()
Definition: tile_distribution.hpp:91
constexpr CK_TILE_HOST_DEVICE const auto & get_ys_to_d_descriptor() const
Definition: tile_distribution.hpp:131
remove_cvref_t< PsYs2XsAdaptor_ > PsYs2XsAdaptor
Definition: tile_distribution.hpp:73
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_r()
Definition: tile_distribution.hpp:92
static constexpr index_t NDimR
Definition: tile_distribution.hpp:84
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:227
Ys2DDescriptor ys_to_d_
Definition: tile_distribution.hpp:87
static constexpr index_t NDimX
Definition: tile_distribution.hpp:81
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_y()
Definition: tile_distribution.hpp:90
static constexpr CK_TILE_HOST_DEVICE auto get_static_tile_distribution_encoding()
Definition: tile_distribution.hpp:133
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition: container_helper.hpp:486
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)
Definition: tensor_adaptor.hpp:840
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)
Definition: tensor_adaptor.hpp:716
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10