/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_distribution_encoding.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_distribution_encoding.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_distribution_encoding.hpp Source File
tile_distribution_encoding.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
17 namespace ck_tile {
18 
19 template <typename RsLengths_, // sequence<...>
20  typename HsLengthss_, // tuple<sequence<...>, ...>
21  typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
22  typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
23  typename Ys2RHsMajor_, // sequence<...>
24  typename Ys2RHsMinor_> // sequence<...>
26 {
33 
34  static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!");
35  static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!");
36 
37  static constexpr index_t NDimX = HsLengthss::size();
38  static constexpr index_t NDimP = Ps2RHssMajor::size();
39  static constexpr index_t NDimY = Ys2RHsMajor::size();
40  static constexpr index_t NDimR = RsLengths::size();
41 
42  // FIXME: move into detail
43  static constexpr auto rs_lengths_ = RsLengths{};
44  static constexpr auto hs_lengthss_ = HsLengthss{};
45  static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
46  static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
47  static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
48  static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
49 
50 #if !CK_TILE_ENC_SUPPORT_Y_TO_R
51  static_assert(container_find(ys_to_rhs_major_, 0) == NDimY,
52  "do not support Y dim pointed to R dim");
53 #endif
54 
55  // redundant but useful info
56  // TODO: really bad code, should be overhauled
57  struct detail
58  {
59  // ndim_rh_major_, ndim_span_major_
60  static constexpr index_t ndim_rh_major_ = NDimX + 1;
61  static constexpr index_t ndim_span_major_ = NDimX;
62 
63  // ndims_rhs_minor_[ndim_rh_major_]
64  static constexpr auto ndims_rhs_minor_ = generate_array(
65  [](auto i) {
66  if constexpr(i.value == 0)
67  {
68  return rs_lengths_.size();
69  }
70  else
71  {
72  return hs_lengthss_[i - number<1>{}].size();
73  }
74  },
76 
77  // max_ndim_rh_minor_
78  static constexpr index_t max_ndim_rh_minor_ =
80 
81  // rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
82  static constexpr auto rhs_lengthss_ =
84 
85  // ys_lengths_[NDimY]: length of every Y dimension.
    // Y dimension i maps to the R/H slot (ys_to_rhs_major_[i], ys_to_rhs_minor_[i]);
    // its length is the entry of rhs_lengthss_ stored at that slot.
86  static constexpr auto ys_lengths_ = [] {
    // Initial contents are placeholders: the loop below writes every index in [0, NDimY).
87  array<index_t, NDimY> ys_lengths_tmp{-1};
88 
89  for(index_t i = 0; i < NDimY; i++)
90  {
91  index_t rh_major = ys_to_rhs_major_[i];
92  index_t rh_minor = ys_to_rhs_minor_[i];
93 
94  ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor];
95  }
96 
97  return ys_lengths_tmp;
98  }();
99 
100  // rhs_major_minor_to_ys_[ndim_rh_major_][max_ndim_rh_minor_]
     // Inverse of (ys_to_rhs_major_, ys_to_rhs_minor_): for an R/H slot (rh_major, rh_minor),
     // stores the index of the Y dimension that maps to it.
     // The outer extent NDimX + 1 is the same value as ndim_rh_major_.
101  static constexpr auto rhs_major_minor_to_ys_ = [] {
     // Slots that no Y dimension maps to keep whatever the {{-1}} initializer produced.
     // NOTE(review): presumably -1 for every entry (a consumer in this struct tests
     // `idim_y >= 0`) -- confirm the brace-init fill semantics of ck_tile::array.
102  array<array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};
103 
104  static_for<0, NDimY, 1>{}([&](auto i) {
105  constexpr index_t rh_major = ys_to_rhs_major_[i];
106  constexpr index_t rh_minor = ys_to_rhs_minor_[i];
107 
     // Record Y index i at the slot it points to.
108  rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
109  });
110 
111  return rhs_major_minor_to_ys_tmp;
112  }();
113 
114  // ndims_span_minor_[ndim_span_major_]
     // Number of Y dimensions that map to each span major, where span major = rh_major - 1.
     // The array has NDimX entries (ndim_span_major_ == NDimX), not NDimY.
115  static constexpr auto ndims_span_minor_ = [] {
     // The counts below rely on every entry starting at 0.
116  array<index_t, NDimX> ndims_span_minor{0};
117 
118  for(index_t i = 0; i < NDimY; i++)
119  {
     // NOTE(review): a Y dimension with rh_major == 0 (pointing to R) would give
     // span_major == -1 here; the static_assert guarded by CK_TILE_ENC_SUPPORT_Y_TO_R
     // rejects that case when the macro is off -- confirm behaviour when it is on.
120  const index_t span_major = ys_to_rhs_major_[i] - 1;
121 
122  ndims_span_minor(span_major)++;
123  }
124 
125  return ndims_span_minor;
126  }();
127 
128  // max_ndim_span_minor_
129  static constexpr index_t max_ndim_span_minor_ =
131 
132  // rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
     // For every R/H slot (rh_major, rh_minor) that has a Y dimension mapped to it, stores
     // the slot's span-minor index: a running count, restarted for each rh_major, of the
     // Y-mapped slots visited so far within that rh_major.
     // Slots without a Y dimension keep the value produced by the {{-1}} initializer.
133  static constexpr auto rhs_major_minor_to_span_minor_ = [] {
134  array<array<index_t, max_ndim_rh_minor_>, ndim_rh_major_> rhs_major_minor_to_span_minor{
135  {-1}};
136 
137  static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
     // Number of minor slots that actually exist under this rh_major.
138  constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major];
139 
     // Counter is local to this rh_major, so numbering restarts at 0 for each major.
140  index_t cnt_ndim_span_minor = 0;
141 
142  static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
143  constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];
144 
     // A non-negative entry means some Y dimension maps to this slot.
145  if(idim_y >= 0)
146  {
147  rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
148 
149  cnt_ndim_span_minor++;
150  }
151  });
152  });
153 
154  return rhs_major_minor_to_span_minor;
155  }();
156 
157  // ys_to_span_major_[NDimY]
     // Span major of each Y dimension: its rh_major minus 1 (rh_major 0 is the R dimension,
     // so H dimension k has rh_major k + 1).
158  static constexpr auto ys_to_span_major_ =
159  generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number<NDimY>{});
160 
161  // ys_to_span_minor_[NDimY]
162  static constexpr auto ys_to_span_minor_ = generate_array(
163  [](auto i) {
165  },
166  number<NDimY>{});
167 
168  // distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
169  static constexpr auto distributed_spans_lengthss_ = [] {
171  distributed_spans_lengthss{{-1}};
172 
173  static_for<0, NDimY, 1>{}([&](auto i) {
174  const index_t rh_major = ys_to_rhs_major_[i];
175  const index_t rh_minor = ys_to_rhs_minor_[i];
176 
177  const index_t h_length = hs_lengthss_[number<rh_major - 1>{}][rh_minor];
178 
179  const index_t span_major = rh_major - 1;
180  const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];
181 
182  distributed_spans_lengthss(span_major)(span_minor) = h_length;
183  });
184 
185  return distributed_spans_lengthss;
186  }();
187 
188  // ndims_distributed_spans_minor_[ndim_span_major_]
     // Number of Y dimensions that map to each span major (span major = rh_major - 1).
     // NOTE(review): this appears to produce the same counts as ndims_span_minor_, using
     // static_for instead of a plain loop -- confirm whether both members are needed.
189  static constexpr auto ndims_distributed_spans_minor_ = [] {
     // The counts below rely on every entry starting at 0.
190  array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};
191 
192  static_for<0, NDimY, 1>{}([&](auto i) {
193  const index_t span_major = ys_to_rhs_major_[i] - 1;
194 
195  ndims_distributed_spans_minor(span_major)++;
196  });
197 
198  return ndims_distributed_spans_minor;
199  }();
200 
201  // does_p_own_r_[NDimP][NDimR]
     // does_p_own_r_[p][r] is set to true when one of the R/H slots listed for P dimension p
     // (in ps_to_rhss_major_ / ps_to_rhss_minor_) is the R slot r, i.e. rh_major == 0 and
     // rh_minor == r. Entries never set keep the value from the {{false}} initializer.
202  static constexpr auto does_p_own_r_ = [] {
203  if constexpr(NDimR > 0)
204  {
205  array<array<bool, NDimR>, NDimP> does_p_own_r{{false}};
206 
207  static_for<0, NDimP, 1>{}([&](auto idim_p) {
     // Number of R/H slots this P dimension is composed of.
208  constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
209 
210  static_for<0, ndim_low, 1>{}([&](auto idim_low) {
211  constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
212  constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
213 
     // rh_major 0 denotes the R dimension; only those slots are recorded.
214  if constexpr(rh_major == 0)
215  {
216  does_p_own_r(idim_p)(rh_minor) = true;
217  }
218  });
219  });
220 
221  return does_p_own_r;
222  }
223  else
224  {
     // No R dimensions: return a value-initialized array whose inner extent is NDimR (0).
225  return array<array<bool, NDimR>, NDimP>{};
226  }
227  }();
228 
229  // ps_over_rs_derivative_[NDimP][NDimR]
230  static constexpr auto ps_over_rs_derivative_ = [] {
231  if constexpr(NDimR > 0)
232  {
233  array<array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};
234 
235  static_for<0, NDimP, 1>{}([&](auto idim_p) {
236  constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
237 
238  index_t p_over_rh_derivative = 1;
239 
240  static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
241  constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
242  constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
243 
244  constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];
245 
246  if constexpr(rh_major == 0)
247  {
248  ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
249  }
250 
251  p_over_rh_derivative *= rh_length;
252  });
253  });
254 
255  return ps_over_rs_derivative;
256  }
257  else
258  {
260  }
261  }();
262 
264  {
265  // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
266  constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
267  [&](auto i) {
268  constexpr index_t size_ = HsLengthss{}[i].size();
269  return number<size_>{};
270  },
271  number<NDimX>{});
272  return uniformed_h_dim_lengths;
273  }
274 
275  // note: this function only counts the P dim lengths along H, not along R
277  {
278  // e.g. tuple<seq<1, 4, 32>, seq<1, 2, 8, 4, 4>>
279  // Y P Y Y P Y P Y
280  // | | |
281  // v v v
282  // return : seq<4, 2 * 4> => seq<4, 8>
283  constexpr auto uniformed_ps_to_rhss_major_ =
284  unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
285  constexpr auto uniformed_ps_to_rhss_minor_ =
286  unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
287 
288  constexpr auto p_len_ = [&]() {
289  array<index_t, NDimX> len_{1};
290  static_for<0, NDimX, 1>{}([&](auto idim_x_) {
291  constexpr auto major_ = number<idim_x_ + 1>{}; // RDim
292  static_for<0, uniformed_ps_to_rhss_major_.size(), 1>{}([&](auto idim_u_) {
293  if constexpr(major_.value == uniformed_ps_to_rhss_major_[idim_u_])
294  {
295  constexpr auto minor_ = uniformed_ps_to_rhss_minor_[idim_u_];
296  constexpr auto h_length_ = hs_lengthss_[idim_x_][minor_];
297  len_[idim_x_] *= h_length_;
298  }
299  });
300  });
301  return len_;
302  }();
303  constexpr auto p_len_over_h_seq_ = TO_SEQUENCE(p_len_, NDimX);
304  return p_len_over_h_seq_;
305  }
306 
307  //
308  // R: seq<3>, H: tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>>
309  // => return seq<1, 3, 5>
310  // R: seq<>, H: tuple<seq<2, 4>, seq<16, 8, 8>>
311  // => return seq<0, 2, 3>
313  {
314  constexpr auto uniformed_rh_dim_lengths =
316 
317  return uniformed_rh_dim_lengths;
318  }
319 
320  // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
322  {
323  // <0, len_d0, len_d0+len_d1, ...>
324  // e.g. seq<3, 5> --> seq<0, 3, 8>
325  constexpr auto h_dim_prefix_sum = prefix_sum_sequence(get_uniformed_h_dim_lengths());
326 
327  return h_dim_prefix_sum;
328  }
329 
331  {
332  // <0, len_d0, len_d0+len_d1, ...>
333  // e.g. seq<3, 5> --> seq<0, 3, 8>
334  constexpr auto rh_dim_prefix_sum = prefix_sum_sequence(get_uniformed_rh_dim_lengths());
335 
336  return rh_dim_prefix_sum;
337  }
338 
340  {
341  // tuple<seq<xx..>, seq<yy..>> -> seq<xx..yy..>
342  constexpr auto uniformed_ps_to_rhss_major_ =
343  unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
344  constexpr auto uniformed_ps_to_rhss_minor_ =
345  unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
346 
347  constexpr auto all_ps_2_rhss = transform_sequences(
348  [](auto major, auto minor) constexpr {
349  constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
350  return rh_dim_prefix_sum.at(major) + minor;
351  },
352  uniformed_ps_to_rhss_major_,
353  uniformed_ps_to_rhss_minor_);
354 
355  return all_ps_2_rhss;
356  }
357 
359  {
360  constexpr auto all_ys_2_rhss = transform_sequences(
361  [](auto major, auto minor) constexpr {
362  constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
363  return rh_dim_prefix_sum.at(major) + minor;
364  },
365  Ys2RHsMajor{},
366  Ys2RHsMinor{});
367 
368  return all_ys_2_rhss;
369  }
370 
372  {
373  // TODO: Y can't point to R
374  constexpr auto all_ys_2_rhss = transform_sequences(
375  [](auto major, auto minor) constexpr {
376  constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
377  return rh_dim_prefix_sum.at(major) + minor - NDimR;
378  },
379  Ys2RHsMajor{},
380  Ys2RHsMinor{});
381 
382  return all_ys_2_rhss;
383  }
384 
385  // Returns a tuple with one sequence per X (H) dimension. Sequence i has one entry per
     // length in HsLengthss[i]; an entry is 1 when some Y dimension maps to that
     // (rh_major == i + 1, rh_minor) slot, and otherwise keeps its initial value.
     // NOTE(review): the initial value is presumably 0 for every entry -- confirm the
     // brace-init fill semantics of ck_tile::array for `m_{0}`.
386  CK_TILE_HOST_DEVICE static constexpr auto get_y_to_h_masks()
387  {
388  constexpr auto masks_ = generate_tuple(
389  [&](auto i) {
     // Number of minor lengths in H dimension i; this is also the mask length.
390  constexpr auto size_ = HsLengthss{}[i].size();
391  constexpr auto current_y_to_h_mask_ = [&]() {
392  array<index_t, size_> m_{0};
393  // TODO: we loop over all y for each h dim
394  for(auto j = 0; j < NDimY; j++)
395  {
     // rh_major 0 is the R dimension, so H dimension i is rh_major i + 1.
396  if(Ys2RHsMajor{}[j] == (i + 1) /*RDim need plus 1*/)
397  {
398  m_[Ys2RHsMinor{}[j]] = 1;
399  }
400  }
401  return m_;
402  }();
403 
     // Convert the constexpr array into a sequence type of length size_
     // (presumably what TO_SEQUENCE does -- macro defined elsewhere).
404  return TO_SEQUENCE(current_y_to_h_mask_, size_);
405  },
406  number<NDimX>{});
407  return masks_;
408  }
409 
410  // return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
     //   sorted_dims       : sequence_unique_sort<IdxSeq, less, equal>::type
     //   sorted_maps       : the matching ::sorted2unsorted_map
     //   sorted_prefix_sum : prefix sum of the histogram of sorted_dims taken against
     //                       PrefixSumSeq
     // Both arguments are used only for their types; their values are never read.
     // NOTE(review): exact semantics of sequence_unique_sort and histogram_sorted_sequence
     // are defined elsewhere -- presumably a deduplicated ascending sort and a per-bin
     // count with PrefixSumSeq as bin boundaries; confirm against sequence.hpp.
411  template <typename IdxSeq, typename PrefixSumSeq>
412  CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
413  {
414  using sorted_idx = sequence_unique_sort<IdxSeq, less<index_t>, equal<index_t>>;
415 
416  constexpr auto sorted_dims = typename sorted_idx::type{};
417  constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};
418 
419  constexpr auto sorted_histogram =
420  histogram_sorted_sequence(sorted_dims, PrefixSumSeq{});
     // Elsewhere in this file prefix_sum_sequence is described as seq<3, 5> --> seq<0, 3, 8>.
421  constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);
422 
423  return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
424  }
425 
426  // Note here y_to_h does not count R dim!
428  {
430  }
431  };
432 };
433 
434 template <typename encoding, typename shuffle>
436 template <typename encoding, index_t... shuffle>
437 class tile_distribution_encoding_shuffle<encoding, sequence<shuffle...>>
438 {
439  template <typename Ys2RHs>
441 
442  public:
443  using type = tile_distribution_encoding<typename encoding::RsLengths,
444  typename encoding::HsLengthss,
445  typename encoding::Ps2RHssMajor,
446  typename encoding::Ps2RHssMinor,
449 };
450 template <typename encoding, typename shuffle>
453 
454 namespace detail {
455 
456 template <typename OuterDstr, typename InnerDstr>
457 CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
458 {
459  static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");
460 
461  constexpr index_t NDimHMajor = OuterDstr::NDimX;
462 
463  using RsLengths =
465 
466  constexpr auto hs_lengthss = generate_tuple(
467  [&](auto i) {
468  return merge_sequences(typename OuterDstr::HsLengthss{}[i],
469  typename InnerDstr::HsLengthss{}[i]);
470  },
472 
473  //
474  constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
475  array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;
476 
477  // R dimension
478  rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();
479 
480  // Hs dimensions
481  static_for<0, NDimHMajor, 1>{}([&](auto i) {
482  rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size();
483  });
484 
485  return rhs_major_2_ndim_outer_rhs_minor_;
486  }();
487 
488  // Ps2RHssMinor
489  constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
490  [&](auto p) {
491  constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
492  constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];
493 
494  constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size();
495 
496  constexpr auto updated_inner_p_2_rhss_minor = [&]() {
497  array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;
498 
499  for(index_t i = 0; i < ndim_tmp; i++)
500  {
501  index_t rh_major = inner_p_2_rhss_major[i];
502 
503  index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
504 
505  updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
506  }
507 
508  return updated_inner_p_2_rhss_minor_;
509  }();
510 
511  return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
512  },
514 
515  // Ys2RHsMinor
516  constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
517  constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
518  constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};
519 
520  constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size();
521 
522  constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
523  array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;
524 
525  for(index_t i = 0; i < ndim_tmp; i++)
526  {
527  index_t rh_major = inner_ys_2_rhs_major[i];
528 
529  index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
530 
531  updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
532  }
533 
534  return updated_inner_ys_2_rhs_minor__;
535  }();
536 
537  return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
538  }();
539 
540  //
541  constexpr auto ps_2_rhss_major =
542  container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});
543 
544  constexpr auto ps_2_rhss_minor =
545  container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
546 
547  //
548  constexpr auto ys_2_rhs_major =
549  merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});
550 
551  constexpr auto ys_2_rhs_minor =
552  merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
553 
554  return tile_distribution_encoding<RsLengths,
555  remove_cvref_t<decltype(hs_lengthss)>,
556  remove_cvref_t<decltype(ps_2_rhss_major)>,
557  remove_cvref_t<decltype(ps_2_rhss_minor)>,
558  remove_cvref_t<decltype(ys_2_rhs_major)>,
559  remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
560 }
561 
562 template <typename InDstr, index_t... InReduceDimXs>
563 CK_TILE_HOST_DEVICE constexpr auto
565 {
566  constexpr auto I1 = number<1>{};
567 
568  // FIXME: hard-coded capacities of the output R and Y arrays below; increase them if an encoding needs more dims
569  constexpr index_t max_ndim_r_out = 20;
570  constexpr index_t max_ndim_y_out = 20;
571 
572  //
573  constexpr index_t ndim_p = InDstr::NDimP;
574  constexpr index_t ndim_x_in = InDstr::NDimX;
575  constexpr index_t ndim_y_in = InDstr::NDimY;
576  constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1;
577  constexpr index_t ndim_x_out = ndim_x_in - sizeof...(InReduceDimXs);
578  constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;
579 
580  // ndims_ps_low
581  constexpr auto ndims_ps_low = generate_array(
582  [&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number<ndim_p>{});
583 
584  // is_rh_major_in_for_reduce
585  array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};
586 
587  for(index_t i = 0; i < reduce_dim_xs_in.size(); i++)
588  {
589  index_t rh_major = reduce_dim_xs_in[i] + 1;
590 
591  is_rh_major_in_for_reduce(rh_major) = true;
592  }
593 
594  // is_y_in_for_reduce
595  array<bool, ndim_y_in> is_y_in_for_reduce{false};
596 
597  for(index_t i = 0; i < ndim_y_in; i++)
598  {
599  index_t rh_major = InDstr::ys_to_rhs_major_[i];
600 
601  if(is_rh_major_in_for_reduce[rh_major])
602  {
603  is_y_in_for_reduce(i) = true;
604  }
605  }
606 
607  // is_rh_minor_in_for_y_reduce
608  array<array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}};
609 
610  static_for<0, ndim_y_in, 1>{}([&](auto i) {
611  index_t rh_major = InDstr::ys_to_rhs_major_[i];
612  index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
613 
614  if(is_y_in_for_reduce[i])
615  {
616  is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
617  }
618  });
619 
620  // in2out_rh_major
621  array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
622  index_t cnt_ndim_rh_major_out = 0;
623 
624  for(index_t i = 0; i < ndim_rh_major_in; i++)
625  {
626  if(is_rh_major_in_for_reduce[i])
627  {
628  in2out_rh_major(i) = 0;
629  }
630  else
631  {
632  in2out_rh_major(i) = cnt_ndim_rh_major_out;
633 
634  cnt_ndim_rh_major_out++;
635  }
636  }
637 
638  // rs_lengths_out, in2out_rh_minor
639  array<index_t, max_ndim_r_out> rs_lengths_out{-1};
640  array<array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};
641 
642  // loop over input R dim
643  for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
644  {
645  // rs_lengths_out
646  rs_lengths_out(i) = InDstr::rs_lengths_[i];
647 
648  // in2out_rh_minor
649  in2out_rh_minor(0)(i) = i;
650  }
651 
652  // loop over input H Dim
653  index_t cnt_ndim_r_out = InDstr::rs_lengths_.size();
654 
655  static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
656  constexpr auto h_major_in = rh_major_in - I1;
657 
658  constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();
659 
660  if(is_rh_major_in_for_reduce[rh_major_in])
661  {
662  for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
663  {
664  if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
665  {
666  // rs_lengths_out
667  rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
668 
669  // in2out_rh_minor
670  in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
671 
672  cnt_ndim_r_out++;
673  }
674  }
675  }
676  else
677  {
678  for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
679  {
680  // in2out_rh_minor
681  in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
682  }
683  }
684  });
685 
686  // ndim_r_out
687  const index_t ndim_r_out = cnt_ndim_r_out;
688 
689  // ndims_hs_minor_out, hs_lengthss_out
690  array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
691  array<array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};
692 
693  index_t cnt_ndim_x_out = 0;
694 
695  static_for<0, ndim_x_in, 1>{}([&](auto i) {
696  if(not is_rh_major_in_for_reduce[i + I1])
697  {
698  // ndims_hs_minor_out
699  ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();
700 
701  // hs_lengthss_out
702  static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
703  [&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
704 
705  cnt_ndim_x_out++;
706  }
707  });
708 
709  // ps_to_rhss_major_out, ps_to_rhss_minor_out
710  array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
711  array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};
712 
713  static_for<0, ndim_p, 1>{}([&](auto idim_p) {
714  static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) {
715  index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
716  index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
717 
718  ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
719  ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
720  });
721  });
722 
723  // ys_to_rhs_major_out, ys_to_rhs_minor_out
724  array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
725  array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};
726 
727  index_t cnt_ndim_y_out = 0;
728 
729  static_for<0, ndim_y_in, 1>{}([&](auto i) {
730  if(not is_y_in_for_reduce[i])
731  {
732  index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
733  index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
734 
735  ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
736  ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
737 
738  cnt_ndim_y_out++;
739  }
740  });
741 
742  // ndim_y_out
743  const index_t ndim_y_out = cnt_ndim_y_out;
744 
745  //
746  return make_tuple(ndim_x_out,
747  ndim_p,
748  ndim_y_out,
749  ndim_r_out,
750  ndims_hs_minor_out,
751  ndims_ps_low,
752  rs_lengths_out,
753  hs_lengthss_out,
754  ps_to_rhss_major_out,
755  ps_to_rhss_minor_out,
756  ys_to_rhs_major_out,
757  ys_to_rhs_minor_out);
758 }
759 
760 template <typename InDstr, index_t... InReduceDimXs>
761 CK_TILE_HOST_DEVICE constexpr auto
763 {
764  constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);
765 
766  constexpr index_t ndim_x = impl.template at<0>();
767  constexpr index_t ndim_p = impl.template at<1>();
768  constexpr index_t ndim_y = impl.template at<2>();
769  constexpr index_t ndim_r = impl.template at<3>();
770  constexpr auto ndims_hs_minor = impl.template at<4>();
771  constexpr auto ndims_ps_low = impl.template at<5>();
772  constexpr auto rs_lengths_impl = impl.template at<6>();
773  constexpr auto hs_lengthss_impl = impl.template at<7>();
774  constexpr auto ps_to_rhss_major_impl = impl.template at<8>();
775  constexpr auto ps_to_rhss_minor_impl = impl.template at<9>();
776  constexpr auto ys_to_rhs_major_impl = impl.template at<10>();
777  constexpr auto ys_to_rhs_minor_impl = impl.template at<11>();
778 
779  constexpr auto rs_lengths = TO_SEQUENCE(rs_lengths_impl, ndim_r);
780  constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
781  constexpr auto ps_to_rhss_major =
782  TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
783  constexpr auto ps_to_rhss_minor =
784  TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
785  constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
786  constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);
787 
788  return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
789  remove_cvref_t<decltype(hs_lengthss)>,
790  remove_cvref_t<decltype(ps_to_rhss_major)>,
791  remove_cvref_t<decltype(ps_to_rhss_minor)>,
792  remove_cvref_t<decltype(ys_to_rhs_major)>,
793  remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
794 }
795 
796 } // namespace detail
797 
798 // Free print function for tile_distribution_encoding::detail
799 template <typename RsLengths_,
800  typename HsLengthss_,
801  typename Ps2RHssMajor_,
802  typename Ps2RHssMinor_,
803  typename Ys2RHsMajor_,
804  typename Ys2RHsMinor_>
806 print(const typename tile_distribution_encoding<RsLengths_,
807  HsLengthss_,
808  Ps2RHssMajor_,
809  Ps2RHssMinor_,
810  Ys2RHsMajor_,
811  Ys2RHsMinor_>::detail& detail_obj)
812 {
813  printf("tile_distribution_encoding::detail{");
814  printf("ndim_rh_major_: ");
815  print(detail_obj.ndim_rh_major_);
816  printf(", ");
817  printf("ndim_span_major_: ");
818  print(detail_obj.ndim_span_major_);
819  printf(", ");
820  printf("ndims_rhs_minor_: ");
821  print(detail_obj.ndims_rhs_minor_);
822  printf(", ");
823  printf("ndim_rh_major_: ");
824  print(detail_obj.ndim_rh_major_);
825  printf(", ");
826  printf("max_ndim_rh_minor_: ");
827  print(detail_obj.max_ndim_rh_minor_);
828  printf(", ");
829  printf("rhs_lengthss_: ");
830  print(detail_obj.rhs_lengthss_);
831  printf(", ");
832  printf("ys_lengths_: ");
833  print(detail_obj.ys_lengths_);
834  printf(", ");
835  printf("rhs_major_minor_to_ys_: ");
836  print(detail_obj.rhs_major_minor_to_ys_);
837  printf(", ");
838  printf("ndims_span_minor_: ");
839  print(detail_obj.ndims_span_minor_);
840  printf(", ");
841  printf("max_ndim_span_minor_: ");
842  print(detail_obj.max_ndim_span_minor_);
843  printf(", ");
844  printf("ys_to_span_major_: ");
845  print(detail_obj.ys_to_span_major_);
846  printf(", ");
847  printf("ys_to_span_minor_: ");
848  print(detail_obj.ys_to_span_minor_);
849  printf(", ");
850  printf("distributed_spans_lengthss_: ");
851  print(detail_obj.distributed_spans_lengthss_);
852  printf(", ");
853  printf("ndims_distributed_spans_minor_: ");
854  print(detail_obj.ndims_distributed_spans_minor_);
855  printf(", ");
856  printf("ps_over_rs_derivative_: ");
857  print(detail_obj.ps_over_rs_derivative_);
858  printf("}");
859 }
860 
861 // Free print function for tile_distribution_encoding
862 template <typename RsLengths_,
863  typename HsLengthss_,
864  typename Ps2RHssMajor_,
865  typename Ps2RHssMinor_,
866  typename Ys2RHsMajor_,
867  typename Ys2RHsMinor_>
869  HsLengthss_,
870  Ps2RHssMajor_,
871  Ps2RHssMinor_,
872  Ys2RHsMajor_,
873  Ys2RHsMinor_>& encoding)
874 {
875  printf("tile_distribution_encoding{");
876 
877  printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY);
878  printf("rs_lengths_: ");
879  print(encoding.rs_lengths_);
880  printf(", ");
881  printf("hs_lengthss_: ");
882  print(encoding.hs_lengthss_);
883  printf(", ");
884  printf("ps_to_rhss_major_: ");
885  print(encoding.ps_to_rhss_major_);
886  printf(", ");
887  printf("ps_to_rhss_minor_: ");
888  print(encoding.ps_to_rhss_minor_);
889  printf(", ");
890  printf("ys_to_rhs_major_: ");
891  print(encoding.ys_to_rhs_major_);
892  printf(", ");
893  printf("ys_to_rhs_minor_: ");
894  print(encoding.ys_to_rhs_minor_);
895  printf(", ");
896  printf("}");
897 }
898 
899 } // namespace ck_tile
Definition: tile_distribution_encoding.hpp:435
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding_impl(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:564
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:762
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
typename sequence_merge< Seqs... >::type sequence_merge_t
Definition: sequence.hpp:1020
constexpr CK_TILE_HOST_DEVICE auto transform_sequences(F f, sequence< Xs... >)
Definition: sequence.hpp:829
typename tile_distribution_encoding_shuffle< encoding, shuffle >::type tile_distribution_encoding_shuffle_t
Definition: tile_distribution_encoding.hpp:452
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1112
constexpr CK_TILE_HOST_DEVICE auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition: tuple.hpp:630
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d< BlockSize, YPerTile, XPerTile, VecSize, DistributionPattern, NumWaveGroups > &)
Definition: static_encoding_pattern.hpp:342
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_sequence_v2(F &&f, number< N >)
Definition: sequence.hpp:1042
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:823
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr index_t container_find(sequence< Is... > seq, index_t value)
Definition: container_helper.hpp:447
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto unpack(F &&f, X &&x)
Definition: functional.hpp:200
constexpr CK_TILE_HOST_DEVICE auto histogram_sorted_sequence(SeqSortedSamples, sequence< r, rs... >)
Definition: sequence.hpp:1099
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
constexpr auto prefix_sum_sequence(Seq)
Definition: sequence.hpp:905
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
static constexpr CK_TILE_HOST_DEVICE auto size()
Definition: array.hpp:97
Definition: integral_constant.hpp:13
Definition: math.hpp:122
Definition: sequence.hpp:590
Definition: sequence.hpp:49
static constexpr CK_TILE_HOST_DEVICE index_t size()
Definition: sequence.hpp:53
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:58
static constexpr index_t max_ndim_span_minor_
Definition: tile_distribution_encoding.hpp:129
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_p_to_h()
Definition: tile_distribution_encoding.hpp:339
static constexpr CK_TILE_HOST_DEVICE auto get_sorted_info(IdxSeq, PrefixSumSeq)
Definition: tile_distribution_encoding.hpp:412
static constexpr auto rhs_lengthss_
Definition: tile_distribution_encoding.hpp:82
static constexpr auto distributed_spans_lengthss_
Definition: tile_distribution_encoding.hpp:169
static constexpr auto does_p_own_r_
Definition: tile_distribution_encoding.hpp:202
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_y_to_h()
Definition: tile_distribution_encoding.hpp:371
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_rh_dim_lengths()
Definition: tile_distribution_encoding.hpp:312
static constexpr index_t max_ndim_rh_minor_
Definition: tile_distribution_encoding.hpp:78
static constexpr auto ys_to_span_major_
Definition: tile_distribution_encoding.hpp:158
static constexpr CK_TILE_HOST_DEVICE auto get_sorted_y_to_h_info()
Definition: tile_distribution_encoding.hpp:427
static constexpr auto rhs_major_minor_to_span_minor_
Definition: tile_distribution_encoding.hpp:133
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_p_dim_lengths_over_h()
Definition: tile_distribution_encoding.hpp:276
static constexpr CK_TILE_HOST_DEVICE auto get_h_dim_lengths_prefix_sum()
Definition: tile_distribution_encoding.hpp:321
static constexpr index_t ndim_span_major_
Definition: tile_distribution_encoding.hpp:61
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_h_dim_lengths()
Definition: tile_distribution_encoding.hpp:263
static constexpr auto ndims_span_minor_
Definition: tile_distribution_encoding.hpp:115
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_y_to_rh()
Definition: tile_distribution_encoding.hpp:358
static constexpr index_t ndim_rh_major_
Definition: tile_distribution_encoding.hpp:60
static constexpr auto ps_over_rs_derivative_
Definition: tile_distribution_encoding.hpp:230
static constexpr CK_TILE_HOST_DEVICE auto get_y_to_h_masks()
Definition: tile_distribution_encoding.hpp:386
static constexpr auto ys_lengths_
Definition: tile_distribution_encoding.hpp:86
static constexpr auto ndims_distributed_spans_minor_
Definition: tile_distribution_encoding.hpp:189
static constexpr auto rhs_major_minor_to_ys_
Definition: tile_distribution_encoding.hpp:101
static constexpr auto ndims_rhs_minor_
Definition: tile_distribution_encoding.hpp:64
static constexpr auto ys_to_span_minor_
Definition: tile_distribution_encoding.hpp:162
static constexpr CK_TILE_HOST_DEVICE auto get_rh_dim_lengths_prefix_sum()
Definition: tile_distribution_encoding.hpp:330
Definition: tile_distribution_encoding.hpp:26
static constexpr index_t NDimR
Definition: tile_distribution_encoding.hpp:40
static constexpr auto ps_to_rhss_minor_
Definition: tile_distribution_encoding.hpp:46
static constexpr auto rs_lengths_
Definition: tile_distribution_encoding.hpp:43
static constexpr index_t NDimP
Definition: tile_distribution_encoding.hpp:38
remove_cvref_t< Ps2RHssMinor_ > Ps2RHssMinor
Definition: tile_distribution_encoding.hpp:30
static constexpr auto ys_to_rhs_major_
Definition: tile_distribution_encoding.hpp:47
static constexpr auto ys_to_rhs_minor_
Definition: tile_distribution_encoding.hpp:48
static constexpr index_t NDimY
Definition: tile_distribution_encoding.hpp:39
remove_cvref_t< Ys2RHsMinor_ > Ys2RHsMinor
Definition: tile_distribution_encoding.hpp:32
static constexpr auto hs_lengthss_
Definition: tile_distribution_encoding.hpp:44
remove_cvref_t< Ys2RHsMajor_ > Ys2RHsMajor
Definition: tile_distribution_encoding.hpp:31
remove_cvref_t< HsLengthss_ > HsLengthss
Definition: tile_distribution_encoding.hpp:28
remove_cvref_t< Ps2RHssMajor_ > Ps2RHssMajor
Definition: tile_distribution_encoding.hpp:29
remove_cvref_t< RsLengths_ > RsLengths
Definition: tile_distribution_encoding.hpp:27
static constexpr auto ps_to_rhss_major_
Definition: tile_distribution_encoding.hpp:45
static constexpr index_t NDimX
Definition: tile_distribution_encoding.hpp:37
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition: container_helper.hpp:486
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10