21 template <
typename Distribution>
24 return Distribution::_get_partition_index();
29 template <
index_t... PartialHsLengths>
40 template <
index_t... PartialHsIndices>
66 template <
typename PsYs2XsAdaptor_,
67 typename Ys2DDescriptor_,
68 typename StaticTileDistributionEncoding_,
69 typename TileDistributionDetail_>
79 "wrong! should be static");
81 static constexpr
index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
82 static constexpr
index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
84 static constexpr
index_t NDimR = StaticTileDistributionEncoding_::NDimR;
97 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
99 if constexpr(
NDimP == 1)
103 else if constexpr(
NDimP == 2)
141 template <
typename PartitionIndex>
144 static_assert(PartitionIndex::size() ==
NDimP,
"wrong!");
153 constexpr
index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();
156 constexpr
index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
157 constexpr
index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];
160 if constexpr(rh_major == 0)
162 constexpr
index_t adaptor_hidden_id =
163 DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];
166 rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
175 template <
typename PartitionIndex = decltype(_get_partition_index())>
180 const auto window_adaptor_thread_coord_tmp =
182 return window_adaptor_thread_coord_tmp.get_bottom_index();
187 constexpr
auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
188 constexpr
auto ndims_spans_minor = DstrEncode::detail::ndims_distributed_spans_minor_;
192 constexpr
auto span_impl = distributed_spans_impl[i];
193 constexpr
index_t ndim_span_minor = ndims_spans_minor[i];
203 template <
typename DistributedIndices>
207 constexpr
auto ys_idx_arr = [] {
211 constexpr
index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
212 constexpr
index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];
216 ys_idx(i) = dstr_index.impl_[span_minor];
235 template <index_t NDimMax>
240 for(
index_t i = 0; i < iend - ibegin; ++i)
249 template <
typename StaticTileDistributionEncoding_>
253 using RsLengths =
typename StaticTileDistributionEncoding_::RsLengths;
254 using HsLengthss =
typename StaticTileDistributionEncoding_::HsLengthss;
255 using Ps2RHssMajor =
typename StaticTileDistributionEncoding_::Ps2RHssMajor;
256 using Ps2RHssMinor =
typename StaticTileDistributionEncoding_::Ps2RHssMinor;
257 using Ys2RHsMajor =
typename StaticTileDistributionEncoding_::Ys2RHsMajor;
258 using Ys2RHsMinor =
typename StaticTileDistributionEncoding_::Ys2RHsMinor;
261 constexpr
index_t kMaxNumTransforms = 20;
262 constexpr
index_t kMaxMetaDataSize = 128;
263 constexpr
index_t kMaxNumDim = 10;
274 constexpr
index_t ndim_x = HsLengthss::size();
283 index_t hidden_dim_cnt = ndim_x;
287 constexpr
index_t ndim_r_minor = RsLengths::size();
289 constexpr
auto r_minor_lengths = RsLengths{};
291 trans(num_tran++) = {
293 MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
296 NumDim{ndim_r_minor},
297 make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};
299 for(
index_t i = 0; i < ndim_r_minor; ++i)
301 rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt;
302 rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];
312 &rh_major_minor_to_hidden_ids,
313 &rh_major_minor_to_hidden_lengths](
auto idim_x) {
315 constexpr
auto h_minor_lengths =
316 HsLengthss{}.get(idim_x);
319 constexpr
index_t ndim_h_minor = h_minor_lengths.size();
321 trans(num_tran++) = {
323 MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
326 NumDim{ndim_h_minor},
327 make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};
329 for(
index_t i = 0; i < ndim_h_minor; ++i)
331 rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt;
332 rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];
339 constexpr
index_t ndim_p = Ps2RHssMajor::size();
341 Dims hidden_dim_id_ps;
345 index_t hidden_dim_id_p = hidden_dim_cnt++;
347 hidden_dim_id_ps(iDimP) = hidden_dim_id_p;
349 constexpr
auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
350 constexpr
auto p2RHsMinor = Ps2RHssMinor{}[iDimP];
352 static_assert(p2RHsMajor.size() == p2RHsMinor.size(),
"wrong!");
354 constexpr
index_t ndim_low = p2RHsMajor.size();
359 for(
index_t i = 0; i < ndim_low; ++i)
361 index_t rh_major = p2RHsMajor[i];
362 index_t rh_minor = p2RHsMinor[i];
363 low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
364 low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
368 MetaData{to_array<index_t, ndim_low>(low_lengths)},
372 Dims{hidden_dim_id_p}};
375 constexpr
index_t ndim_bottom = ndim_x;
377 constexpr
auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);
379 constexpr
auto ys_to_rhs_major = Ys2RHsMajor{};
380 constexpr
auto ys_to_rhs_minor = Ys2RHsMinor{};
382 constexpr
index_t ndim_y = Ys2RHsMajor::size();
383 constexpr
index_t ndim_top = ndim_p + ndim_y;
385 auto top_dim_ids = hidden_dim_id_ps;
388 for(
index_t i = 0; i < ndim_y; ++i)
390 index_t rh_major = ys_to_rhs_major[i];
391 index_t rh_minor = ys_to_rhs_minor[i];
392 top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
397 const auto ps_ys_to_xs_adaptor_encoding =
398 make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);
404 for(
index_t i = 0; i < ndim_y; ++i)
406 index_t rh_major = ys_to_rhs_major[i];
407 index_t rh_minor = ys_to_rhs_minor[i];
408 index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
409 y_lengths(i) = y_length;
410 d_length *= y_length;
414 MetaData{to_array<index_t, ndim_y>(y_lengths)},
418 make_sequential_index<kMaxNumDim>(1, ndim_y + 1));
420 const auto ys_to_d_adaptor_encoding =
make_tuple(
421 make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);
423 return make_tuple(ps_ys_to_xs_adaptor_encoding,
424 ys_to_d_adaptor_encoding,
426 rh_major_minor_to_hidden_ids);
430 template <
typename RhMajorMinor2AdaptorHiddenIdss>
441 template <
typename StaticTileDistributionEncoding_>
444 using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
446 constexpr
auto adaptor_impl =
449 constexpr
auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
450 constexpr
auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
451 constexpr
index_t d_length = adaptor_impl.template at<2>();
452 constexpr
auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
454 constexpr
auto ps_ys_to_xs_adaptor =
459 constexpr
auto ys_to_d_descriptor =
463 constexpr
index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
464 constexpr
auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
466 constexpr
auto rh_major_minor_to_hidden_ids =
469 return tile_distribution<
472 remove_cvref_t<DstrEncode>,
473 detail::tile_distribution_detail<
remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
474 ps_ys_to_xs_adaptor, ys_to_d_descriptor};
479 template <
typename StaticTileDistributionEncoding_>
484 constexpr
auto adaptor_impl =
487 constexpr
auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
488 constexpr
auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
489 constexpr
index_t d_length = adaptor_impl.template at<2>();
490 constexpr
auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
492 constexpr
auto ps_ys_to_xs_adaptor =
495 constexpr
auto ys_to_d_adaptor =
498 constexpr
auto ys_to_d_descriptor =
502 constexpr
index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
503 constexpr
auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
505 constexpr
auto rh_major_minor_to_hidden_ids =
513 ps_ys_to_xs_adaptor, ys_to_d_descriptor};
553 template <
typename Distribution,
index_t... XSliceBegins,
index_t... XSliceEnds>
559 using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
561 static_assert(
sizeof...(XSliceBegins) ==
sizeof...(XSliceEnds));
562 static_assert(
sizeof...(XSliceBegins) == Encoding::NDimX,
"only support slice over h, not r");
564 constexpr
auto p_len_over_h = Encoding::detail::get_uniformed_p_dim_lengths_over_h();
568 if constexpr(x_slice_ends[i] == -1)
571 constexpr
auto x_length_ =
577 return x_slice_ends[i];
582 constexpr
auto x_slice_lengths = x_slice_ends_ - x_slice_begins;
585 [&](
auto i) constexpr {
586 constexpr
auto len_ = x_slice_lengths[i];
587 static_assert(len_ % p_len_over_h[i] == 0,
588 "slice length must be dividable by p_len_over_h");
589 return number<len_ / p_len_over_h[i]>{};
591 number<x_slice_lengths.size()>{});
593 constexpr
auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
594 constexpr
auto src_y_info = Encoding::detail::get_sorted_y_to_h_info();
595 constexpr
auto src_y_dims = src_y_info[
number<0>{}];
596 constexpr
auto src_y_maps = src_y_info[
number<1>{}];
597 constexpr
auto src_y_prefix_sum = src_y_info[
number<2>{}];
599 constexpr
auto sliced_hlen_yidx_ylen = [&]() constexpr {
600 auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
601 auto y_slice_lengths = Encoding::detail::ys_lengths_;
602 constexpr
auto y_to_h_masks = Encoding::detail::get_y_to_h_masks();
608 [&](
auto h_len,
auto id) {
610 h_len,
number<x_slice_lengths_without_p[
id]>{}, y_to_h_masks[id]);
612 constexpr
auto sliced_h_lens = sliced_h[
number<0>{}];
613 constexpr
auto sliced_h_index = sliced_h[
number<2>{}];
617 constexpr
auto found_y_index =
container_find(src_y_dims, uniformed_h_index);
618 constexpr
auto y_to_h_dim_end = src_y_prefix_sum[
id + 1];
620 static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
621 "not sliced at y dim, please check");
624 constexpr
auto sliced_y_to_h_lens =
626 constexpr
auto sliced_y_to_h_dims = sliced_y_to_h_lens.size();
628 y_slice_lengths(src_y_maps[y_to_h_dim_end - 1 - i]) =
629 sliced_y_to_h_lens[sliced_y_to_h_dims - 1 - i];
636 constexpr
auto y_origin = [&]() {
638 constexpr
auto y_to_h_len =
640 constexpr
auto y_to_h_dims = y_to_h_len.size();
644 constexpr
auto y_begin_ = x_slice_begins[id] / p_len_over_h[id];
647 auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
650 y_origin_(y_to_h_dim_end - 1 - i) = h_origin_[y_to_h_dims - 1 - i];
656 src_y_prefix_sum[
id + 1],
661 return sliced_h_lens;
663 typename Encoding::HsLengthss{},
668 return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
671 constexpr
auto sliced_h_lengths = sliced_hlen_yidx_ylen[
number<0>{}];
672 constexpr
auto sliced_y_origins_array = sliced_hlen_yidx_ylen[
number<1>{}];
673 constexpr
auto sliced_y_origins_size = sliced_y_origins_array.size();
674 constexpr
auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[
number<2>{}];
675 constexpr
auto sliced_y_lengths_size = sliced_y_lengths_array.size();
677 constexpr
auto sliced_y_origins =
TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
678 constexpr
auto sliced_y_lengths =
TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
686 typename Encoding::Ps2RHssMajor,
687 typename Encoding::Ps2RHssMinor,
688 typename Encoding::Ys2RHsMajor,
689 typename Encoding::Ys2RHsMinor>{}),
697 template <
typename PsYs2XsAdaptor_,
698 typename Ys2DDescriptor_,
699 typename StaticTileDistributionEncoding_,
700 typename TileDistributionDetail_>
703 StaticTileDistributionEncoding_,
704 TileDistributionDetail_>& distribution)
706 printf(
"tile_distribution{");
707 printf(
"tile_distribution_encoding: ");
708 print(StaticTileDistributionEncoding_{});
710 printf(
"ps_ys_to_xs_: ");
711 print(distribution.ps_ys_to_xs_);
713 printf(
"ys_to_d_: ");
714 print(distribution.ys_to_d_);
Concept for encoding of Unicode characters.
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr CK_TILE_HOST_DEVICE auto make_sequential_index(index_t ibegin, index_t iend)
Definition: tile_distribution.hpp:236
constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_span(sequence< Is... >)
Definition: tile_distribution.hpp:53
constexpr CK_TILE_HOST_DEVICE auto slice_distribution_from_x(Distribution, sequence< XSliceBegins... > x_slice_begins, sequence< XSliceEnds... > x_slice_ends)
Definition: tile_distribution.hpp:554
constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_index(sequence< Is... >)
Definition: tile_distribution.hpp:59
constexpr CK_TILE_HOST_DEVICE auto make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:251
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_zero_multi_index()
Definition: multi_index.hpp:26
constexpr CK_TILE_HOST_DEVICE auto container_reorder_given_old2new(const array< TData, NSize > &old_array, sequence< IRs... > old2new)
Definition: container_helper.hpp:48
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
coord_transform_enum
Definition: coordinate_transform.hpp:17
constexpr CK_TILE_HOST_DEVICE auto pick_sequence_elements_by_mask(Seq, Mask)
Definition: sequence.hpp:942
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constexpr auto reverse_slice_sequence(Seq, number< SliceSize >, Mask=typename uniform_sequence_gen< Seq::size(), 1 >::type{})
Definition: sequence.hpp:1220
constexpr CK_TILE_HOST_DEVICE auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition: tuple.hpp:630
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d< BlockSize, YPerTile, XPerTile, VecSize, DistributionPattern, NumWaveGroups > &)
Definition: static_encoding_pattern.hpp:342
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_sequence_v2(F &&f, number< N >)
Definition: sequence.hpp:1042
constexpr CK_TILE_HOST_DEVICE auto make_tensor_descriptor_from_adaptor(const Adaptor &adaptor, const ElementSpaceSize &element_space_size)
Definition: tensor_descriptor.hpp:177
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1609
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr index_t container_find(sequence< Is... > seq, index_t value)
Definition: container_helper.hpp:447
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
constexpr CK_TILE_HOST_DEVICE auto transform_tuples(F f, const X &x)
Definition: tuple.hpp:505
Definition: sequence.hpp:284
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: tile_distribution.hpp:432
static constexpr auto rh_major_minor_to_adaptor_hidden_idss_
Definition: tile_distribution.hpp:433
Definition: sequence.hpp:49
static constexpr CK_TILE_HOST_DEVICE index_t size()
Definition: sequence.hpp:53
Definition: functional.hpp:43
Definition: tile_distribution.hpp:42
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:47
static constexpr auto impl_
Definition: tile_distribution.hpp:45
Definition: tile_distribution.hpp:31
static constexpr auto impl_
Definition: tile_distribution.hpp:34
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:36
Definition: tile_distribution_encoding.hpp:26
Definition: tile_distribution.hpp:72
remove_cvref_t< Ys2DDescriptor_ > Ys2DDescriptor
Definition: tile_distribution.hpp:74
PsYs2XsAdaptor ps_ys_to_xs_
Definition: tile_distribution.hpp:86
static constexpr CK_TILE_HOST_DEVICE auto get_distributed_spans()
Definition: tile_distribution.hpp:185
static CK_TILE_HOST_DEVICE auto _get_partition_index()
Definition: tile_distribution.hpp:94
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
static constexpr index_t NDimY
Definition: tile_distribution.hpp:82
remove_cvref_t< StaticTileDistributionEncoding_ > DstrEncode
Definition: tile_distribution.hpp:75
remove_cvref_t< TileDistributionDetail_ > DstrDetail
Definition: tile_distribution.hpp:76
CK_TILE_HOST_DEVICE auto calculate_index(const PartitionIndex &ps_idx=_get_partition_index()) const
Definition: tile_distribution.hpp:177
static constexpr CK_TILE_HOST_DEVICE auto get_lengths()
Definition: tile_distribution.hpp:109
static constexpr index_t NDimP
Definition: tile_distribution.hpp:83
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_x()
Definition: tile_distribution.hpp:89
static constexpr CK_TILE_HOST_DEVICE auto get_y_indices_from_distributed_indices(DistributedIndices)
Definition: tile_distribution.hpp:205
CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex &ps_idx) const
Definition: tile_distribution.hpp:142
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_p()
Definition: tile_distribution.hpp:91
constexpr CK_TILE_HOST_DEVICE const auto & get_ys_to_d_descriptor() const
Definition: tile_distribution.hpp:131
remove_cvref_t< PsYs2XsAdaptor_ > PsYs2XsAdaptor
Definition: tile_distribution.hpp:73
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_r()
Definition: tile_distribution.hpp:92
static constexpr index_t NDimR
Definition: tile_distribution.hpp:84
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:227
Ys2DDescriptor ys_to_d_
Definition: tile_distribution.hpp:87
static constexpr index_t NDimX
Definition: tile_distribution.hpp:81
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_y()
Definition: tile_distribution.hpp:90
static constexpr CK_TILE_HOST_DEVICE auto get_static_tile_distribution_encoding()
Definition: tile_distribution.hpp:133
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition: container_helper.hpp:486
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)
Definition: tensor_adaptor.hpp:840
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)
Definition: tensor_adaptor.hpp:716
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10