/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/algorithm/coordinate_transform.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/algorithm/coordinate_transform.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/algorithm/coordinate_transform.hpp Source File
coordinate_transform.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
13 
14 namespace ck_tile {
15 
17 {
18  undefined,
20  pad,
21  embed,
22  merge,
23  unmerge,
24  replicate,
25  xor_t,
26  offset,
27  indexing,
28 };
29 
30 template <index_t NDimLow, index_t NDimUp>
32 {
33  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
34  {
36  }
37 
38  CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return NDimLow; }
39 
40  CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return NDimUp; }
41 
42  // return safe value for vector length/stride, based on compile-time known only
43  // variables
44  // MUST be static function
45  template <typename LowVectorLengths, typename LowVectorStrides>
46  CK_TILE_HOST_DEVICE static constexpr auto
48  const LowVectorStrides&)
49  {
50  if constexpr(NDimUp > 0)
51  {
52  array<index_t, NDimUp> up_vector_lengths{-1};
53  array<index_t, NDimUp> up_vector_strides{-1};
54 
55  return make_tuple(up_vector_lengths, up_vector_strides);
56  }
57  else
58  {
60  }
61  }
62 };
63 
64 template <typename LowLength>
65 struct pass_through : public base_transform<1, 1>
66 {
68 
71 
72  using UpLengths = decltype(make_tuple(LowLength{}));
73 
75 
76  CK_TILE_HOST_DEVICE constexpr pass_through() = default;
77 
78  CK_TILE_HOST_DEVICE constexpr pass_through(const LowLength& low_length)
79  : up_lengths_{make_tuple(low_length)}
80  {
81  }
82 
83  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
84  {
86  }
87 
88  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
89 
90  template <typename LowIdx, typename UpIdx>
91  CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low,
92  const UpIdx& idx_up)
93  {
94  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
95  "wrong! inconsistent # of dimension");
96 
97  idx_low(number<0>{}) = idx_up[number<0>{}];
98  }
99 
100  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
101  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
102  const UpIdxDiff& idx_diff_up,
103  LowIdx& idx_low,
104  const UpIdx&)
105  {
106  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
107  UpIdx::size() == 1,
108  "wrong! inconsistent # of dimension");
109 
110  constexpr auto I0 = number<0>{};
111 
112  idx_diff_low[I0] = idx_diff_up[I0];
113 
114  idx_low += idx_diff_low;
115  }
116 
117  CK_TILE_HOST_DEVICE static constexpr bool
119  {
120  return true;
121  }
122 
123  template <typename UpIdx>
124  CK_TILE_HOST_DEVICE static constexpr bool
126  {
127  return true;
128  }
129 
131  {
133  }
134 
135  // MUST be static function
136  template <typename LowVectorLengths, typename LowVectorStrides>
137  CK_TILE_HOST_DEVICE static constexpr auto
138  calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
139  const LowVectorStrides& low_vector_strides)
140  {
141  return make_tuple(low_vector_lengths, low_vector_strides);
142  }
143 };
144 
145 template <typename LowLength>
146 CK_TILE_HOST_DEVICE static void print(const pass_through<LowLength>& pt)
147 {
148  printf("pass_through{");
149 
150  printf("up_lengths_: ");
151  print(pt.get_upper_lengths());
152 
153  printf("}");
154 }
155 
156 template <typename LowLength,
157  typename LeftPadLength,
158  typename RightPadLength,
159  bool SkipIsValidCheck = false>
160 struct pad : public base_transform<1, 1>
161 {
164 
165  using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));
166 
168  LeftPadLength left_pad_length_;
169  RightPadLength right_pad_length_;
170 
172 
173  CK_TILE_HOST_DEVICE constexpr pad(const LowLength& low_length,
174  const LeftPadLength& left_pad_length,
175  const RightPadLength& right_pad_length)
176  : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
177  left_pad_length_{left_pad_length},
178  right_pad_length_{right_pad_length}
179  {
180  }
181 
182  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
183 
184  template <typename LowIdx, typename UpIdx>
185  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
186  const UpIdx& idx_up) const
187  {
188  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
189  "wrong! inconsistent # of dimension");
190 
191  idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_;
192  }
193 
194  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
195  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
196  const UpIdxDiff& idx_diff_up,
197  LowIdx& idx_low,
198  const UpIdx&)
199  {
200  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
201  UpIdx::size() == 1,
202  "wrong! inconsistent # of dimension");
203 
204  constexpr auto I0 = number<0>{};
205 
206  idx_diff_low[I0] = idx_diff_up[I0];
207 
208  idx_low += idx_diff_low;
209  }
210 
211  CK_TILE_HOST_DEVICE static constexpr bool
213  {
214  return SkipIsValidCheck;
215  }
216 
217  template <typename UpIdx>
218  CK_TILE_HOST_DEVICE constexpr bool
220  {
221  return SkipIsValidCheck ||
222  ((idx_up[number<0>{}] >= left_pad_length_) &&
223  (idx_up[number<0>{}] < up_lengths_[number<0>{}] - right_pad_length_));
224  }
225 
227  {
231  }
232 };
233 
234 template <typename LowLength,
235  typename LeftPadLength,
236  typename RightPadLength,
237  bool SkipIsValidCheck>
238 CK_TILE_HOST_DEVICE static void
239 print(const pad<LowLength, LeftPadLength, RightPadLength, SkipIsValidCheck>& p)
240 {
241  printf("pad{");
242  printf("up_lengths_: ");
243  print(p.up_lengths_);
244  printf(", left_pad_length_: ");
245  print(p.left_pad_length_);
246  printf(", right_pad_length_: ");
247  print(p.right_pad_length_);
248  printf("}");
249 }
250 
251 template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
252 struct left_pad
253 {
256 
257  using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));
258 
260  LeftPadLength left_pad_length_;
261 
262  CK_TILE_HOST_DEVICE constexpr left_pad() = default;
263 
264  CK_TILE_HOST_DEVICE constexpr left_pad(const LowLength& low_length,
265  const LeftPadLength& left_pad_length)
266  : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
267  {
268  }
269 
270  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
271 
272  template <typename LowIdx, typename UpIdx>
273  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
274  const UpIdx& idx_up) const
275  {
276  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
277  "wrong! inconsistent # of dimension");
278 
279  idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_;
280  }
281 
282  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
283  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
284  const UpIdxDiff& idx_diff_up,
285  LowIdx& idx_low,
286  const UpIdx&)
287  {
288  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
289  UpIdx::size() == 1,
290  "wrong! inconsistent # of dimension");
291 
292  constexpr auto I0 = number<0>{};
293 
294  idx_diff_low[I0] = idx_diff_up[I0];
295 
296  idx_low += idx_diff_low;
297  }
298 
299  CK_TILE_HOST_DEVICE static constexpr bool
301  {
302  return SkipIsValidCheck;
303  }
304 
305  template <typename UpIdx>
306  CK_TILE_HOST_DEVICE constexpr bool
308  {
309  return SkipIsValidCheck || (idx_up[number<0>{}] >= left_pad_length_);
310  }
311 
313  {
316  }
317 
318  // MUST be static function
319  template <typename LowVectorLengths, typename LowVectorStrides>
320  CK_TILE_HOST_DEVICE static constexpr auto
321  calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
322  const LowVectorStrides& low_vector_strides)
323  {
324  // TODO: we allow pass through this vector length. If one need per-pixel check,
325  // should change the guaranteed vector length while creating the tensor view.
326  // It's up to runtime to check the padding length should be multiple of vector length
327  return make_tuple(low_vector_lengths, low_vector_strides);
328  }
329 };
330 
331 template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck>
332 CK_TILE_HOST_DEVICE static void
333 print(const left_pad<LowLength, LeftPadLength, SkipIsValidCheck>& lp)
334 {
335  printf("left_pad{");
336  printf("up_lengths_: ");
337  print(lp.up_lengths_);
338  printf(", left_pad_length_: ");
339  print(lp.left_pad_length_);
340  printf("}");
341 }
342 
343 template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
344 struct right_pad : public base_transform<1, 1>
345 {
348 
349  using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));
350 
352  LowLength low_length_;
353  RightPadLength right_pad_length_;
354 
355  CK_TILE_HOST_DEVICE constexpr right_pad() = default;
356 
357  CK_TILE_HOST_DEVICE constexpr right_pad(const LowLength& low_length,
358  const RightPadLength& right_pad_length)
359  : up_lengths_{make_tuple(low_length + right_pad_length)},
360  low_length_{low_length},
361  right_pad_length_{right_pad_length}
362  {
363  }
364 
365  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
366 
367  template <typename LowIdx, typename UpIdx>
368  CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low,
369  const UpIdx& idx_up)
370  {
371  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
372  "wrong! inconsistent # of dimension");
373 
374  idx_low(number<0>{}) = idx_up[number<0>{}];
375  }
376 
377  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
378  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
379  const UpIdxDiff& idx_diff_up,
380  LowIdx& idx_low,
381  const UpIdx&)
382  {
383  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
384  UpIdx::size() == 1,
385  "wrong! inconsistent # of dimension");
386 
387  constexpr auto I0 = number<0>{};
388 
389  idx_diff_low[I0] = idx_diff_up[I0];
390 
391  idx_low += idx_diff_low;
392  }
393 
394  CK_TILE_HOST_DEVICE static constexpr bool
396  {
397  return SkipIsValidCheck;
398  }
399 
400  template <typename UpIdx>
401  CK_TILE_HOST_DEVICE constexpr bool
403  {
404  return SkipIsValidCheck || (idx_up[number<0>{}] < low_length_);
405  }
406 
408  {
412  }
413 
414  // MUST be static function
415  template <typename LowVectorLengths, typename LowVectorStrides>
416  CK_TILE_HOST_DEVICE static constexpr auto
417  calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
418  const LowVectorStrides& low_vector_strides)
419  {
420  // TODO: we allow pass through this vector length. If one need per-pixel check,
421  // should change the guaranteed vector length while creating the tensor view.
422  // It's up to runtime to check the padding length should be multiple of vector length
423  return make_tuple(low_vector_lengths, low_vector_strides);
424  }
425 };
426 
427 template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
428 CK_TILE_HOST_DEVICE static void
429 print(const right_pad<LowLength, RightPadLength, SkipIsValidCheck>& rp)
430 {
431  printf("right_pad{");
432  printf("up_lengths_: ");
433  print(rp.up_lengths_);
434  printf(", right_pad_length_: ");
435  print(rp.right_pad_length_);
436  printf("}");
437 }
438 
439 // idx_low = sum over i in [0, nDimUp) of coefficients[i] * idx_up[i] (a dot product)
440 // UpLengths and Coefficients can each be one of the following:
441 // 1) Tuple of index_t, which is known at run-time, or
442 // 2) Tuple of number, which is known at compile-time, or
443 // 3) Tuple of mixture of index_t and number, which is known partially at run-time and partially
444 // at compile-time
445 template <typename UpLengths,
446  typename Coefficients,
447  typename std::enable_if<UpLengths::size() == Coefficients::size(), bool>::type = false>
448 struct embed : public base_transform<1, UpLengths::size()>
449 {
450  static constexpr index_t NDimUp = UpLengths::size();
451 
454 
455  UpLengths up_lengths_;
456  Coefficients coefficients_;
457 
458  CK_TILE_HOST_DEVICE constexpr embed() = default;
459 
460  CK_TILE_HOST_DEVICE constexpr embed(const UpLengths& up_lengths,
461  const Coefficients& coefficients)
462  : up_lengths_{up_lengths}, coefficients_{coefficients}
463  {
464  }
465 
466  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
467  {
469  }
470 
471  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
472 
473  template <typename LowIdx, typename UpIdx>
474  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
475  const UpIdx& idx_up) const
476  {
477  static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
478  "wrong! inconsistent # of dimension");
479 
480  idx_low(number<0>{}) = 0;
481 
482  static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
483  idx_low(number<0>{}) += idx_up[i] * this->coefficients_[i];
484  });
485  }
486 
487  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
488  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
489  const UpIdxDiff& idx_diff_up,
490  LowIdx& idx_low,
491  const UpIdx&) const
492  {
493  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
494  LowIdx::size() == 1 && UpIdx::size() == NDimUp,
495  "wrong! inconsistent # of dimension");
496 
497  idx_diff_low(number<0>{}) = 0;
498 
500  [&](auto i) { idx_diff_low(number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
501 
502  idx_low += idx_diff_low;
503  }
504 
505  CK_TILE_HOST_DEVICE static constexpr bool
507  {
508  return true;
509  }
510 
511  template <typename UpIdx>
512  CK_TILE_HOST_DEVICE static constexpr bool
514  {
515  return true;
516  }
517 
519  {
522  }
523 };
524 
525 template <typename UpLengths, typename Coefficients>
526 CK_TILE_HOST_DEVICE static void print(const embed<UpLengths, Coefficients>& e)
527 {
528  printf("embed{");
529  printf("up_lengths_: ");
530  print(e.up_lengths_);
531  printf(", coefficients_: ");
532  print(e.coefficients_);
533  printf("}");
534 }
535 
536 template <typename LowLengths>
538 {
539  template <index_t I>
540  CK_TILE_HOST_DEVICE constexpr auto operator()(number<I> i) const
541  {
542  return magic_division::calculate_magic_numbers(LowLengths{}[i]);
543  }
544 };
545 
546 // Implementation of "merge" transformation primitive that uses magic-number-division to do lowering
547 // of both multi-index and delta of multi-index
548 // Caution:
549 // 1. The magic number division implementation being used produces the correct result if the
550 // dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
551 // 2. The magic number division for an int32_t dividend has not been implemented; an int32_t
552 // dividend is bit-wise reinterpreted as uint32_t and the magic number division implementation for
553 // uint32_t is then used.
554 // 3. For the merge primitive, the upper index is the dividend.
555 // 4. When the upper index is uint32_t, its value needs to be within the 31-bit range.
556 // 5. When the upper index is int32_t (i.e. when index_t is int32_t), its value needs to be
557 // non-negative.
558 template <typename LowLengths>
559 struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
560 {
561  static constexpr index_t NDimLow = LowLengths::size();
562 
565 
566  using UpLengths =
567  decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
568 
571  number<NDimLow>{}));
572 
573  LowLengths low_lengths_;
576 
577  static constexpr auto I0 = number<0>{};
578  static constexpr auto I1 = number<1>{};
579 
581 
582  CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division(const LowLengths& low_lengths)
583  : low_lengths_{low_lengths},
585  [&](auto i) { return magic_division::calculate_magic_numbers(low_lengths[i]); },
586  number<NDimLow>{})},
587  up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, I1))}
588  {
589  static_assert(LowerIndex::size() == NDimLow, "wrong!");
590  }
591 
592  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
593  {
595  }
596 
597  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
598 
599  template <typename LowIdx, typename UpIdx>
600  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
601  const UpIdx& idx_up) const
602  {
603  static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1,
604  "wrong! inconsistent # of dimension");
605 
606  index_t tmp = idx_up[I0];
607 
608  static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
609  index_t tmp2 =
611  this->low_lengths_magic_divisor_[i][I0],
612  this->low_lengths_magic_divisor_[i][I1]);
613  idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
614  tmp = tmp2;
615  });
616 
617  idx_low(number<0>{}) = tmp;
618  }
619 
620  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
621  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
622  const UpIdxDiff&,
623  LowIdx& idx_low,
624  const UpIdx& idx_up_new) const
625  {
626  static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 &&
627  LowIdx::size() == NDimLow && UpIdx::size() == 1,
628  "wrong! inconsistent # of dimension");
629 
630  index_t tmp = idx_up_new[number<0>{}];
631 
632  static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
633  index_t tmp2 =
635  this->low_lengths_magic_divisor_[i][I0],
636  this->low_lengths_magic_divisor_[i][I1]);
637 
638  index_t idx_low_old = idx_low[i];
639 
640  idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
641  tmp = tmp2;
642 
643  idx_diff_low(i) = idx_low[i] - idx_low_old;
644  });
645 
646  idx_diff_low(number<0>{}) = tmp - idx_low(number<0>{});
647 
648  idx_low(number<0>{}) = tmp;
649  }
650 
651  CK_TILE_HOST_DEVICE static constexpr bool
653  {
654  return true;
655  }
656 
658  {
662  }
663 
664  template <typename UpIdx>
665  CK_TILE_HOST_DEVICE static constexpr bool
667  {
668  return true;
669  }
670 
671  // MUST be static function
672  template <typename LowVectorLengths, typename LowVectorStrides>
673  CK_TILE_HOST_DEVICE static constexpr auto
674  calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
675  const LowVectorStrides& low_vector_strides)
676  {
677  array<index_t, 1> up_vector_lengths{-1};
678  array<index_t, 1> up_vector_strides{-1};
679 
680  up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
681  up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
682 
683  return make_tuple(up_vector_lengths, up_vector_strides);
684  }
685 };
686 
687 template <typename LowLengths>
688 CK_TILE_HOST_DEVICE static void print(const merge_v2_magic_division<LowLengths>& m)
689 {
690  printf("merge_v2_magic_division{");
691  printf("low_lengths_: ");
692  print(m.low_lengths_);
693  printf(", up_lengths_: ");
694  print(m.up_lengths_);
695  printf("}");
696 }
697 
698 // Implementation of "merge" transformation primitive that uses division and mod. It is supposed to
699 // be used for low_lengths that are known at compile time and are power of 2, otherwise performance
700 // will be very bad
701 template <typename LowLengths>
702 struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
703 {
704  static constexpr index_t NDimLow = LowLengths::size();
705 
708 
710  decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number<1>{}));
711 
712  using UpLengths =
713  decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
714 
715  LowLengths low_lengths_;
718 
720 
721  CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths& low_lengths)
722  : low_lengths_{low_lengths},
723  low_lengths_scan_{
724  container_reverse_exclusive_scan(low_lengths, multiplies{}, number<1>{})},
725  up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, number<1>{}))}
726  {
727  static_assert(LowerIndex::size() == NDimLow, "wrong!");
728  }
729 
730  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
731 
732  template <typename LowIdx, typename UpIdx>
733  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
734  const UpIdx& idx_up) const
735  {
736  static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1,
737  "wrong! inconsistent # of dimension");
738 
739  index_t tmp = idx_up[number<0>{}];
740 
741  // division and mod
742  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
743  idx_low(i) = tmp / this->low_lengths_scan_[i];
744  tmp %= this->low_lengths_scan_[i];
745  });
746 
747  idx_low(number<NDimLow - 1>{}) = tmp;
748  }
749 
750  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
751  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
752  const UpIdxDiff&,
753  LowIdx& idx_low,
754  const UpIdx& idx_up_new) const
755  {
756  static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 &&
757  LowIdx::size() == NDimLow && UpIdx::size() == 1,
758  "wrong! inconsistent # of dimension");
759 
760  constexpr auto I0 = number<0>{};
761  constexpr auto INm1 = number<NDimLow - 1>{};
762 
763  index_t tmp = idx_up_new[I0];
764 
765  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
766  const index_t tmp2 = idx_low[i];
767  idx_low(i) = tmp / this->low_lengths_scan_[i];
768  idx_diff_low(i) = idx_low[i] - tmp2;
769  tmp %= this->low_lengths_scan_[i];
770  });
771 
772  const index_t tmp2 = idx_low[INm1];
773  idx_low(INm1) = tmp;
774  idx_diff_low(INm1) = idx_low[INm1] - tmp2;
775  }
776 
777  CK_TILE_HOST_DEVICE static constexpr bool
779  {
780  return true;
781  }
782 
784  {
788  }
789 
790  template <typename UpIdx>
791  CK_TILE_HOST_DEVICE static constexpr bool
793  {
794  return true;
795  }
796 
797  // MUST be static function
798  template <typename LowVectorLengths, typename LowVectorStrides>
799  CK_TILE_HOST_DEVICE static constexpr auto
800  calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
801  const LowVectorStrides& low_vector_strides)
802  {
803  array<index_t, 1> up_vector_lengths{-1};
804  array<index_t, 1> up_vector_strides{-1};
805 
806  up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
807  up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
808 
809  return make_tuple(up_vector_lengths, up_vector_strides);
810  }
811 };
812 
813 template <typename LowLengths>
814 CK_TILE_HOST_DEVICE static void print(const merge_v3_division_mod<LowLengths>& m)
815 {
816  printf("merge_v3_division_mod{");
817  printf("low_lengths_: ");
818  print(m.low_lengths_);
819  printf(", low_lengths_scan_: ");
820  print(m.low_lengths_scan_);
821  printf(", up_lengths_: ");
822  print(m.up_lengths_);
823  printf("}");
824 }
825 
826 template <typename UpLengths, bool Use24BitIntegerCalculation>
827 struct unmerge : public base_transform<1, UpLengths::size()>
828 {
829  static constexpr index_t NDimUp = UpLengths::size();
830 
833 
835  decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number<1>{}));
836 
837  UpLengths up_lengths_;
839 
840  CK_TILE_HOST_DEVICE constexpr unmerge() = default;
841 
842  CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths& up_lengths)
843  : up_lengths_{up_lengths},
844  up_lengths_scan_{container_reverse_exclusive_scan(up_lengths, multiplies{}, number<1>{})}
845  {
846  }
847 
848  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
849  {
851  }
852 
853  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
854 
855  template <typename LowIdx, typename UpIdx>
856  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
857  const UpIdx& idx_up) const
858  {
859  if constexpr(!Use24BitIntegerCalculation)
860  {
861  idx_low(number<0>{}) = idx_up[number<NDimUp - 1>{}];
862 
863  static_for<0, NDimUp - 1, 1>{}(
864  [&](auto i) { idx_low(number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
865  }
866  else
867  {
868  idx_low(number<0>{}) = idx_up[number<NDimUp - 1>{}];
869 
870  static_for<0, NDimUp - 1, 1>{}([&](auto i) {
871  idx_low(number<0>{}) =
872  (0x00ffffff & idx_low[number<0>{}]) +
873  (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]);
874  });
875  }
876  }
877 
878  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
879  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
880  const UpIdxDiff& idx_diff_up,
881  LowIdx& idx_low,
882  const UpIdx&) const
883  {
884  calculate_lower_index(idx_diff_low, idx_diff_up);
885 
886  idx_low += idx_diff_low;
887  }
888 
889  CK_TILE_HOST_DEVICE static constexpr bool
891  {
892  return true;
893  }
894 
895  template <typename UpIdx>
896  CK_TILE_HOST_DEVICE static constexpr bool
898  {
899  return true;
900  }
901 
903  {
906  }
907 
908  // MUST be static function
909  template <typename LowVectorLengths, typename LowVectorStrides>
910  CK_TILE_HOST_DEVICE static constexpr auto
911  calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
912  const LowVectorStrides& low_vector_strides)
913  {
914  array<index_t, NDimUp> up_vector_lengths{-1};
915  array<index_t, NDimUp> up_vector_strides{-1};
916 
917  constexpr auto up_length_last = UpLengths{}[number<NDimUp - 1>{}];
918 
919  if constexpr(ck_tile::is_known_at_compile_time<decltype(up_length_last)>::value)
920  {
921  if(low_vector_lengths[0] != -1)
922  {
923  up_vector_lengths(NDimUp - 1) = gcd(low_vector_lengths[0], up_length_last);
924  }
925  }
926 
927  up_vector_strides(NDimUp - 1) = low_vector_strides[0];
928 
929  return make_tuple(up_vector_lengths, up_vector_strides);
930  }
931 };
932 
933 template <typename UpLengths, bool Use24BitIntegerCalculation>
934 CK_TILE_HOST_DEVICE static void print(const unmerge<UpLengths, Use24BitIntegerCalculation>& u)
935 {
936  printf("unmerge{");
937  printf("up_lengths_: ");
938  print(u.up_lengths_);
939  printf(", up_lengths_scan_: ");
940  print(u.up_lengths_scan_);
941  printf("}");
942 }
943 
944 template <typename LowerIndex>
945 struct freeze : public base_transform<1, 0>
946 {
947  LowerIndex low_idx_;
948 
949  CK_TILE_HOST_DEVICE constexpr freeze() = default;
950 
951  CK_TILE_HOST_DEVICE constexpr freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
952 
953  CK_TILE_HOST_DEVICE static constexpr auto get_upper_lengths() { return tuple<>{}; }
954 
955  template <typename LowIdx, typename UpIdx>
956  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
957  const UpIdx& /* idx_up */) const
958  {
959  static_assert(LowIdx::size() == 1 && UpIdx::size() == 0,
960  "wrong! inconsistent # of dimension");
961 
962  idx_low(number<0>{}) = low_idx_;
963  }
964 
965  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
966  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
967  const UpIdxDiff& /* idx_diff_up */,
968  LowIdx& /* idx_low */,
969  const UpIdx& /* idx_up_new */)
970  {
971  idx_diff_low(number<0>{}) = 0;
972  }
973 
974  CK_TILE_HOST_DEVICE static constexpr bool
976  {
977  return true;
978  }
979 
980  template <typename UpIdx>
981  CK_TILE_HOST_DEVICE static constexpr bool
983  {
984  return true;
985  }
986 
988  {
990  }
991 };
992 
993 template <typename LowerIndex>
994 CK_TILE_HOST_DEVICE static void print(const freeze<LowerIndex>& f)
995 {
996  printf("freeze{");
997  printf("low_idx_: ");
998  print(f.low_idx_);
999  printf("}");
1000 }
1001 
1002 // insert a dangling upper dimension without lower dimension
1003 template <typename UpperLength>
1004 struct insert : public base_transform<0, 1>
1005 {
1006  using UpLengths = decltype(make_tuple(UpperLength{}));
1007 
1009 
1010  CK_TILE_HOST_DEVICE constexpr insert() = default;
1011 
1012  CK_TILE_HOST_DEVICE constexpr insert(const UpperLength& up_length)
1013  : up_lengths_{make_tuple(up_length)}
1014  {
1015  }
1016 
1017  CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return 0; }
1018 
1019  CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return 1; }
1020 
1021  CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; }
1022 
1023  template <typename LowIdx, typename UpIdx>
1024  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const
1025  {
1026  static_assert(LowIdx::size() == 0 && UpIdx::size() == 1,
1027  "wrong! inconsistent # of dimension");
1028  }
1029 
1030  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1031  CK_TILE_HOST_DEVICE static void
1032  update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
1033  {
1034  static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == 1 && LowIdx::size() == 0 &&
1035  UpIdx::size() == 1,
1036  "wrong! inconsistent # of dimension");
1037  }
1038 
1039  CK_TILE_HOST_DEVICE static constexpr bool IsLinearTransform() { return true; }
1040 
1041  CK_TILE_HOST_DEVICE static constexpr bool
1043  {
1044  return true;
1045  }
1046 
1047  template <typename UpIdx>
1048  CK_TILE_HOST_DEVICE static constexpr bool
1050  {
1051  return true;
1052  }
1053 
1055  {
1057  }
1058 };
1059 
1060 template <typename UpperLength>
1061 CK_TILE_HOST_DEVICE static void print(const insert<UpperLength>& i)
1062 {
1063  printf("insert{");
1064  printf("up_lengths_: ");
1065  print(i.up_lengths_);
1066  printf("}");
1067 }
1068 
1069 // replicate the original tensor and create a higher dimensional tensor
1070 template <typename UpLengths>
1071 struct replicate : public base_transform<0, UpLengths::size()>
1072 {
1073  static constexpr index_t NDimUp = UpLengths::size();
1074 
1075  CK_TILE_HOST_DEVICE constexpr replicate() = default;
1076 
1077  CK_TILE_HOST_DEVICE constexpr replicate(const UpLengths& up_lengths) : up_lengths_{up_lengths}
1078  {
1079  }
1080 
1081  CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; }
1082 
1083  template <typename LowIdx, typename UpIdx>
1084  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const
1085  {
1086  static_assert(LowIdx::size() == 0 && UpIdx::size() == NDimUp,
1087  "wrong! inconsistent # of dimension");
1088  }
1089 
1090  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1091  CK_TILE_HOST_DEVICE static void
1092  update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
1093  {
1094  static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == NDimUp &&
1095  LowIdx::size() == 0 && UpIdx::size() == NDimUp,
1096  "wrong! inconsistent # of dimension");
1097  }
1098 
1099  CK_TILE_HOST_DEVICE static constexpr bool
1101  {
1102  return true;
1103  }
1104 
1105  template <typename UpIdx>
1106  CK_TILE_HOST_DEVICE static constexpr bool
1108  {
1109  return true;
1110  }
1111 
1113  {
1115  }
1116 
1117  //
1118  UpLengths up_lengths_;
1119 };
1120 
1121 template <typename UpLengths>
1122 CK_TILE_HOST_DEVICE static void print(const replicate<UpLengths>& r)
1123 {
1124  printf("replicate{");
1125  printf("up_lengths_: ");
1126  print(r.up_lengths_);
1127  printf("}");
1128 }
1129 
1130 template <typename LowLength, typename SliceBegin, typename SliceEnd>
1131 struct slice : public base_transform<1, 1>
1132 {
1135 
1136  using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));
1137 
1139  SliceBegin slice_begin_;
1140  SliceEnd slice_end_;
1141 
1142  CK_TILE_HOST_DEVICE constexpr slice() = default;
1143 
1144  CK_TILE_HOST_DEVICE constexpr slice(const LowLength&,
1145  const SliceBegin& slice_begin,
1146  const SliceEnd& slice_end)
1147  : up_lengths_{make_tuple(slice_end - slice_begin)},
1148  slice_begin_{slice_begin},
1149  slice_end_{slice_end}
1150  {
1151  }
1152 
1153  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1154 
1155  template <typename LowIdx, typename UpIdx>
1156  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1157  const UpIdx& idx_up) const
1158  {
1159  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
1160  "wrong! inconsistent # of dimension");
1161 
1162  idx_low(number<0>{}) = idx_up[number<0>{}] + slice_begin_;
1163  }
1164 
1165  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1166  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
1167  const UpIdxDiff& idx_diff_up,
1168  LowIdx& idx_low,
1169  const UpIdx&)
1170  {
1171  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
1172  UpIdx::size() == 1,
1173  "wrong! inconsistent # of dimension");
1174 
1175  constexpr auto I0 = number<0>{};
1176 
1177  idx_diff_low[I0] = idx_diff_up[I0];
1178 
1179  idx_low += idx_diff_low;
1180  }
1181 
1182  CK_TILE_HOST_DEVICE static constexpr bool
1184  {
1185  return true;
1186  }
1187 
1188  template <typename UpIdx>
1189  CK_TILE_HOST_DEVICE constexpr bool
1191  {
1192  return true;
1193  }
1194 
1196  {
1200  }
1201 };
1202 
1203 template <typename LowLength, typename SliceBegin, typename SliceEnd>
1204 CK_TILE_HOST_DEVICE static void print(const slice<LowLength, SliceBegin, SliceEnd>& s)
1205 {
1206  printf("slice{");
1207  printf("up_lengths_: ");
1208  print(s.up_lengths_);
1209  printf(", slice_begin_: ");
1210  print(s.slice_begin_);
1211  printf(", slice_end_: ");
1212  print(s.slice_end_);
1213  printf("}");
1214 }
1215 
1216 /*
1217  * \brief lower_idx = upper_idx % modulus.
1218  * TODO: Need an improved implementation since the modulo operation is expensive.
1219  */
1220 template <typename Modulus, typename UpLength>
1221 struct modulo : public base_transform<1, 1>
1222 {
1225  using UpLengths = decltype(make_tuple(UpLength{}));
1226 
1227  Modulus modulus_;
1228  UpLengths up_lengths_;
1229 
1230  CK_TILE_HOST_DEVICE constexpr modulo() = default;
1231 
1232  CK_TILE_HOST_DEVICE constexpr modulo(const Modulus& modulus, const UpLength& up_length)
1233  : modulus_{modulus}, up_lengths_{make_tuple(up_length)}
1234  {
1235  }
1236 
1237  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1238 
1239  template <typename LowIdx, typename UpIdx>
1240  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1241  const UpIdx& idx_up) const
1242  {
1243  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
1244  "wrong! inconsistent # of dimension");
1245 
1246  idx_low(number<0>{}) = idx_up[number<0>{}] % modulus_;
1247  }
1248 
1249  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1250  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
1251  const UpIdxDiff& idx_diff_up,
1252  LowIdx& idx_low,
1253  const UpIdx& up_idx) const
1254  {
1255  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
1256  UpIdx::size() == 1,
1257  "wrong! inconsistent # of dimension");
1258 
1259  constexpr auto I0 = number<0>{};
1260 
1261  const auto idx_low_old = idx_low;
1262  idx_low[I0] = (up_idx[I0] + idx_diff_up[I0]) % modulus_;
1263  idx_diff_low[I0] = idx_low - idx_low_old;
1264  }
1265 
1266  CK_TILE_HOST_DEVICE static constexpr bool
1268  {
1269  return true;
1270  }
1271 
1272  template <typename UpIdx>
1273  CK_TILE_HOST_DEVICE static constexpr bool
1274  is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx&)
1275  {
1276  return true;
1277  }
1278 
1280  {
1282  }
1283 };
1284 
1285 template <typename Modulus, typename UpLength>
1286 CK_TILE_HOST_DEVICE static void print(const modulo<Modulus, UpLength>& m)
1287 {
1288  printf("modulo{");
1289  printf("modulus_: ");
1290  print(m.modulus_);
1291  printf(", up_lengths_: ");
1292  print(m.up_lengths_);
1293  printf("}");
1294 }
1295 
1296 // 2D XOR, NOTE: "xor" is a keyword
1297 template <typename LowLengths>
1298 struct xor_t : public base_transform<2, 2>
1299 {
1300  static constexpr auto type_enum = coord_transform_enum::xor_t;
1301 
1304 
1305  using UpLengths = LowLengths;
1306 
1308 
1309  CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}
1310 
1311  CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
1312 
1313  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
1314  {
1315  return coord_transform_enum::xor_t;
1316  }
1317 
1318  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1319 
1320  template <typename LowIdx, typename UpIdx>
1321  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1322  const UpIdx& idx_up) const
1323  {
1324  static_assert(LowIdx::size() == 2 && UpIdx::size() == 2,
1325  "wrong! inconsistent # of dimension");
1326 
1327  idx_low(number<0>{}) = idx_up[number<0>{}];
1328 
1329  idx_low(number<1>{}) =
1330  idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
1331  }
1332 
1333  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1334  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
1335  const UpIdxDiff&,
1336  LowIdx& idx_low,
1337  const UpIdx& idx_up) const
1338  {
1339  static_assert(LowIdxDiff::size() == 2 && UpIdxDiff::size() == 2 && LowIdx::size() == 2 &&
1340  UpIdx::size() == 2,
1341  "wrong! inconsistent # of dimension");
1342 
1343  const auto idx_low_old = idx_low;
1344 
1345  calculate_lower_index(idx_low, idx_up);
1346 
1347  idx_diff_low = idx_low - idx_low_old;
1348  }
1349 
1350  CK_TILE_HOST_DEVICE static constexpr bool
1352  {
1353  return true;
1354  }
1355 
1356  template <typename UpIdx>
1357  CK_TILE_HOST_DEVICE static constexpr bool
1359  {
1360  return true;
1361  }
1362 
1364  {
1366  }
1367 
1368  // MUST be static function
1369  template <typename LowVectorLengths, typename LowVectorStrides>
1371  const LowVectorLengths& low_vector_lengths,
1372  const LowVectorStrides& low_vector_strides) const
1373  {
1374  array<index_t, 2> up_vector_lengths = low_vector_lengths;
1375  array<index_t, 2> up_vector_strides = low_vector_strides;
1376 
1377  return make_tuple(up_vector_lengths, up_vector_strides);
1378  }
1379 };
1380 
1381 template <typename LowLengths>
1382 CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths>& x)
1383 {
1384  printf("xor_t{");
1385  printf("up_lengths_: ");
1386  print(x.up_lengths_);
1387  printf("}");
1388 }
1389 
1390 template <typename LowLength, typename OffsetLength>
1391 struct offset : public base_transform<1, 1>
1392 {
1395 
1396  using UpLengths = decltype(make_tuple(LowLength{}));
1397 
1399  OffsetLength offset_length_;
1400 
1401  CK_TILE_HOST_DEVICE constexpr offset() = default;
1402 
1403  CK_TILE_HOST_DEVICE constexpr offset(const LowLength& low_length,
1404  const OffsetLength& offset_length)
1405  : up_lengths_{make_tuple(low_length)}, offset_length_{offset_length}
1406  {
1407  }
1408 
1409  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
1410  {
1411  return coord_transform_enum::offset;
1412  }
1413 
1414  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1415 
1416  template <typename LowIdx, typename UpIdx>
1417  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1418  const UpIdx& idx_up) const
1419  {
1420  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
1421  "wrong! inconsistent # of dimension");
1422 
1423  idx_low(number<0>{}) = idx_up[number<0>{}] + offset_length_;
1424  }
1425 
1426  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1427  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
1428  const UpIdxDiff& idx_diff_up,
1429  LowIdx& idx_low,
1430  const UpIdx&)
1431  {
1432  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
1433  UpIdx::size() == 1,
1434  "wrong! inconsistent # of dimension");
1435 
1436  constexpr auto I0 = number<0>{};
1437 
1438  idx_diff_low[I0] = idx_diff_up[I0];
1439 
1440  idx_low += idx_diff_low;
1441  }
1442 
1443  CK_TILE_HOST_DEVICE static constexpr bool
1445  {
1446  return true;
1447  }
1448 
1449  template <typename UpIdx>
1450  CK_TILE_HOST_DEVICE constexpr bool
1452  {
1453  return true;
1454  }
1455 
1457  {
1460  }
1461 };
1462 
1463 template <typename LowLength, typename OffsetLength>
1464 CK_TILE_HOST_DEVICE static void print(const offset<LowLength, OffsetLength>& o)
1465 {
1466  printf("offset{");
1467  printf("up_lengths_: ");
1468  print(o.up_lengths_);
1469  printf(", offset_length_: ");
1470  print(o.offset_length_);
1471  printf("}");
1472 }
1473 
1474 template <typename UpLength, typename IndexingAdaptor>
1475 struct indexing : public base_transform<1, 1>
1476 {
1477  static constexpr index_t NDimUp = 1;
1478 
1481 
1482  using UpLengths = decltype(make_tuple(UpLength{}));
1483  UpLengths up_lengths_;
1484  IndexingAdaptor iadaptor_;
1485 
1486  CK_TILE_HOST_DEVICE constexpr indexing() = default;
1487 
1488  CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
1489  const IndexingAdaptor& iadaptor)
1490  : up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
1491  {
1492  }
1493 
1494  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
1495  {
1496  return coord_transform_enum::indexing;
1497  }
1498 
1499  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1500 
1501  template <typename LowIdx, typename UpIdx>
1502  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1503  const UpIdx& idx_up) const
1504  {
1505  static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
1506  "wrong! inconsistent # of dimension");
1507  iadaptor_.calculate_lower_index(idx_low, idx_up);
1508  }
1509 
1510  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1511  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
1512  const UpIdxDiff& idx_diff_up,
1513  LowIdx& idx_low,
1514  const UpIdx& idx_up) const
1515  {
1516  // TODO: nothing changed here
1517  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
1518  LowIdx::size() == 1 && UpIdx::size() == NDimUp,
1519  "wrong! inconsistent # of dimension");
1520 
1521  iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
1522  }
1523 
1524  CK_TILE_HOST_DEVICE static constexpr bool
1525  is_valid_upper_index_always_mapped_to_valid_lower_index()
1526  {
1527  return true;
1528  }
1529 
1530  template <typename UpIdx>
1531  CK_TILE_HOST_DEVICE static constexpr bool
1532  is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx&)
1533  {
1534  return true;
1535  }
1536 
1537  CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
1538  {
1541  }
1542 };
1543 
1544 template <typename UpLength, typename IndexingAdaptor>
1545 CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>& i)
1546 {
1547  printf("indexing{");
1548  printf("up_lengths_: ");
1549  print(i.up_lengths_);
1550  printf(", iadaptor_: ");
1551  print(i.iadaptor_);
1552  printf("}");
1553 }
1554 
1555 //*******************************************************************************************************
1556 
1557 template <typename LowLength>
1558 CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength& low_length)
1559 {
1560  return pass_through<LowLength>{low_length};
1561 }
1562 
1563 template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
1564 CK_TILE_HOST_DEVICE constexpr auto
1565 make_pad_transform(const LowLength& low_length,
1566  const LeftPad& left_pad,
1567  const RightPad& right_pad,
1569 {
1570  return pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
1571 }
1572 
1573 template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
1574 CK_TILE_HOST_DEVICE constexpr auto
1575 make_left_pad_transform(const LowLength& low_length,
1576  const LeftPadLength& left_pad_,
1578 {
1579  return left_pad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad_};
1580 }
1581 
1582 template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
1583 CK_TILE_HOST_DEVICE constexpr auto
1584 make_right_pad_transform(const LowLength& low_length,
1585  const RightPadLength& right_pad_,
1587 {
1588  return right_pad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad_};
1589 }
1590 
1591 template <typename UpLengths,
1592  typename Coefficients,
1593  typename std::enable_if<UpLengths::size() == Coefficients::size(), bool>::type = false>
1594 CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths& up_lengths,
1595  const Coefficients& coefficients)
1596 {
1597  return embed<UpLengths, Coefficients>{up_lengths, coefficients};
1598 }
1599 
1600 template <typename LowLengths>
1601 CK_TILE_HOST_DEVICE constexpr auto
1602 make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
1603 {
1604  return merge_v2_magic_division<LowLengths>{low_lengths};
1605 }
1606 
1607 template <typename LowLengths>
1608 CK_TILE_HOST_DEVICE constexpr auto
1609 make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
1610 {
1611  return merge_v3_division_mod<LowLengths>{low_lengths};
1612 }
1613 
1614 template <typename LowLengths>
1615 CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths& low_lengths)
1616 {
1617  return make_merge_transform_v2_magic_division(low_lengths);
1618 }
1619 
1620 template <typename UpLengths, bool Use24BitIntegerCalculation = false>
1621 CK_TILE_HOST_DEVICE constexpr auto
1622 make_unmerge_transform(const UpLengths& up_lengths,
1624 {
1625  return unmerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
1626 }
1627 
1628 template <typename LowerIndex>
1629 CK_TILE_HOST_DEVICE constexpr auto make_freeze_transform(const LowerIndex& low_idx)
1630 {
1631  return freeze<LowerIndex>{low_idx};
1632 }
1633 
1634 template <typename UpperIndex>
1635 CK_TILE_HOST_DEVICE constexpr auto make_insert_transform(const UpperIndex& up_idx)
1636 {
1637  return insert<UpperIndex>{up_idx};
1638 }
1639 
1640 template <typename UpLengths>
1641 CK_TILE_HOST_DEVICE constexpr auto make_replicate_transform(const UpLengths& up_lengths)
1642 {
1643  return replicate<UpLengths>{up_lengths};
1644 }
1645 
1646 template <typename LowLength, typename SliceBegin, typename SliceEnd>
1647 CK_TILE_HOST_DEVICE constexpr auto make_slice_transform(const LowLength& low_length,
1648  const SliceBegin& slice_begin,
1649  const SliceEnd& slice_end)
1650 {
1651  return slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
1652 }
1653 
1654 template <typename Modulus, typename UpLength>
1655 CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
1656  const UpLength& up_length)
1657 {
1658  return modulo<Modulus, UpLength>{modulus, up_length};
1659 }
1660 
1661 template <typename LowLengths>
1662 CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
1663 {
1664  return xor_t<LowLengths>{low_lengths};
1665 }
1666 
1667 template <typename LowLength, typename OffsetLength>
1668 CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_length,
1669  const OffsetLength& offset_length)
1670 {
1671  return offset<LowLength, OffsetLength>{low_length, offset_length};
1672 }
1673 
1674 } // namespace ck_tile
1675 
1677 namespace ck_tile {
1678 
1679 template <typename UpLength, typename Indices>
1680 CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
1681  const Indices& indices)
1682 {
1683  // by default we use the simplest one
1686 }
1687 
1688 template <typename UpLength, typename IndexingAdaptor>
1689 CK_TILE_HOST_DEVICE constexpr auto
1690 make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
1691 {
1692  return indexing<UpLength, IndexingAdaptor>{up_lengths, iadaptor};
1693 }
1694 
1695 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
__host__ constexpr __device__ auto unmerge(const Layout< Shape, UnrolledDesc > &layout, const NewLengths &new_lengths, [[maybe_unused]] const NewIdxs &new_indexes)
Unmerge selected dim in layout.
Definition: layout_utils.hpp:474
__host__ constexpr __device__ index_t gcd(index_t x, index_t y)
Definition: math.hpp:154
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_insert_transform(const UpperIndex &up_idx)
Definition: coordinate_transform.hpp:1635
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
constexpr CK_TILE_HOST_DEVICE auto make_left_pad_transform(const LowLength &low_length, const LeftPadLength &left_pad_, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1575
coord_transform_enum
Definition: coordinate_transform.hpp:17
constexpr CK_TILE_HOST_DEVICE auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad_, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1584
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
constexpr CK_TILE_HOST_DEVICE auto make_indexing_transform_with_adaptor(const UpLength &up_lengths, const IndexingAdaptor &iadaptor)
Definition: coordinate_transform.hpp:1690
constexpr CK_TILE_HOST_DEVICE auto make_offset_transform(const LowLength &low_length, const OffsetLength &offset_length)
Definition: coordinate_transform.hpp:1668
is_static< T > is_known_at_compile_time
Definition: type_traits.hpp:94
constexpr CK_TILE_HOST_DEVICE auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: coordinate_transform.hpp:1647
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
constexpr CK_TILE_HOST_DEVICE auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1565
constexpr CK_TILE_HOST_DEVICE auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1622
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1609
constexpr CK_TILE_HOST_DEVICE auto make_modulo_transform(const Modulus &modulus, const UpLength &up_length)
Definition: coordinate_transform.hpp:1655
constexpr CK_TILE_HOST_DEVICE auto make_indexing_transform(const UpLength &up_lengths, const Indices &indices)
Definition: coordinate_transform.hpp:1680
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto make_xor_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1662
constexpr CK_TILE_HOST_DEVICE auto make_replicate_transform(const UpLengths &up_lengths)
Definition: coordinate_transform.hpp:1641
constexpr CK_TILE_HOST_DEVICE auto make_freeze_transform(const LowerIndex &low_idx)
Definition: coordinate_transform.hpp:1629
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v2_magic_division(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1602
constexpr CK_TILE_HOST_DEVICE auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: coordinate_transform.hpp:1594
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
__host__ constexpr __device__ auto container_reverse_exclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition: container_helper.hpp:213
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition: container_helper.hpp:111
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: coordinate_transform.hpp:32
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:33
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_upper_dimension()
Definition: coordinate_transform.hpp:40
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_lower_dimension()
Definition: coordinate_transform.hpp:38
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &, const LowVectorStrides &)
Definition: coordinate_transform.hpp:47
Definition: integral_constant.hpp:13
Definition: coordinate_transform.hpp:449
constexpr CK_TILE_HOST_DEVICE embed()=default
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:471
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:506
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:474
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &) const
Definition: coordinate_transform.hpp:488
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:466
UpLengths up_lengths_
Definition: coordinate_transform.hpp:455
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:513
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:518
Coefficients coefficients_
Definition: coordinate_transform.hpp:456
constexpr CK_TILE_HOST_DEVICE embed(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: coordinate_transform.hpp:460
static constexpr index_t NDimUp
Definition: coordinate_transform.hpp:450
Definition: coordinate_transform.hpp:946
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &) const
Definition: coordinate_transform.hpp:956
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &, const UpIdx &)
Definition: coordinate_transform.hpp:966
LowerIndex low_idx_
Definition: coordinate_transform.hpp:947
static constexpr CK_TILE_HOST_DEVICE auto get_upper_lengths()
Definition: coordinate_transform.hpp:953
constexpr CK_TILE_HOST_DEVICE freeze()=default
constexpr CK_TILE_HOST_DEVICE freeze(const LowerIndex &low_idx)
Definition: coordinate_transform.hpp:951
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:975
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:987
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:982
Definition: type_traits.hpp:76
Definition: indexing_adaptor.hpp:20
Definition: coordinate_transform.hpp:1476
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:1532
constexpr CK_TILE_HOST_DEVICE indexing()=default
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1502
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1525
decltype(make_tuple(UpLength{})) UpLengths
Definition: coordinate_transform.hpp:1482
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1483
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:1499
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1511
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1537
IndexingAdaptor iadaptor_
Definition: coordinate_transform.hpp:1484
constexpr CK_TILE_HOST_DEVICE indexing(const UpLength &up_length, const IndexingAdaptor &iadaptor)
Definition: coordinate_transform.hpp:1488
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:1494
Definition: coordinate_transform.hpp:1005
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &, const UpIdxDiff &, LowIdx &, const UpIdx &)
Definition: coordinate_transform.hpp:1032
constexpr CK_TILE_HOST_DEVICE insert(const UpperLength &up_length)
Definition: coordinate_transform.hpp:1012
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1008
decltype(make_tuple(UpperLength{})) UpLengths
Definition: coordinate_transform.hpp:1006
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:1049
constexpr CK_TILE_HOST_DEVICE insert()=default
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &, const UpIdx &) const
Definition: coordinate_transform.hpp:1024
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_lower_dimension()
Definition: coordinate_transform.hpp:1017
constexpr CK_TILE_HOST_DEVICE auto get_upper_lengths() const
Definition: coordinate_transform.hpp:1021
static constexpr CK_TILE_HOST_DEVICE bool IsLinearTransform()
Definition: coordinate_transform.hpp:1039
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1054
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1042
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_upper_dimension()
Definition: coordinate_transform.hpp:1019
constexpr CK_TILE_HOST_DEVICE auto operator()(number< I > i) const
Definition: coordinate_transform.hpp:540
Definition: coordinate_transform.hpp:253
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:300
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:312
LeftPadLength left_pad_length_
Definition: coordinate_transform.hpp:260
UpLengths up_lengths_
Definition: coordinate_transform.hpp:259
constexpr CK_TILE_HOST_DEVICE left_pad()=default
decltype(make_tuple(LowLength{}+LeftPadLength{})) UpLengths
Definition: coordinate_transform.hpp:257
constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:307
constexpr CK_TILE_HOST_DEVICE left_pad(const LowLength &low_length, const LeftPadLength &left_pad_length)
Definition: coordinate_transform.hpp:264
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition: coordinate_transform.hpp:283
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:270
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:273
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition: coordinate_transform.hpp:321
static constexpr CK_TILE_HOST_DEVICE auto calculate_magic_numbers(uint32_t divisor)
Definition: magic_div.hpp:29
static constexpr CK_TILE_DEVICE uint32_t do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
Definition: magic_div.hpp:60
Definition: coordinate_transform.hpp:560
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition: coordinate_transform.hpp:674
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new) const
Definition: coordinate_transform.hpp:621
LowLengthsMagicDivisor low_lengths_magic_divisor_
Definition: coordinate_transform.hpp:574
static constexpr auto I1
Definition: coordinate_transform.hpp:578
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:652
static constexpr index_t NDimLow
Definition: coordinate_transform.hpp:561
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:657
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:597
decltype(generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_divisor< LowLengths >{}, number< NDimLow >{})) LowLengthsMagicDivisor
Definition: coordinate_transform.hpp:571
LowLengths low_lengths_
Definition: coordinate_transform.hpp:573
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:666
static constexpr auto I0
Definition: coordinate_transform.hpp:577
UpLengths up_lengths_
Definition: coordinate_transform.hpp:575
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number< 1 >{}))) UpLengths
Definition: coordinate_transform.hpp:567
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:600
constexpr CK_TILE_HOST_DEVICE merge_v2_magic_division()=default
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:592
constexpr CK_TILE_HOST_DEVICE merge_v2_magic_division(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:582
Definition: coordinate_transform.hpp:703
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number< 1 >{}))) UpLengths
Definition: coordinate_transform.hpp:713
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition: coordinate_transform.hpp:800
decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number< 1 >{})) LowLengthsScan
Definition: coordinate_transform.hpp:710
LowLengths low_lengths_
Definition: coordinate_transform.hpp:715
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:792
UpLengths up_lengths_
Definition: coordinate_transform.hpp:717
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:730
constexpr CK_TILE_HOST_DEVICE merge_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:721
LowLengthsScan low_lengths_scan_
Definition: coordinate_transform.hpp:716
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:783
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:733
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new) const
Definition: coordinate_transform.hpp:751
constexpr CK_TILE_HOST_DEVICE merge_v3_division_mod()=default
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:778
Definition: coordinate_transform.hpp:1222
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1228
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:1274
decltype(make_tuple(UpLength{})) UpLengths
Definition: coordinate_transform.hpp:1225
constexpr CK_TILE_HOST_DEVICE modulo(const Modulus &modulus, const UpLength &up_length)
Definition: coordinate_transform.hpp:1232
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &up_idx) const
Definition: coordinate_transform.hpp:1250
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1240
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:1237
Modulus modulus_
Definition: coordinate_transform.hpp:1227
constexpr CK_TILE_HOST_DEVICE modulo()=default
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1267
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1279
Definition: math.hpp:98
Definition: coordinate_transform.hpp:1392
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition: coordinate_transform.hpp:1427
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1456
decltype(make_tuple(LowLength{})) UpLengths
Definition: coordinate_transform.hpp:1396
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1444
OffsetLength offset_length_
Definition: coordinate_transform.hpp:1399
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:1414
constexpr CK_TILE_HOST_DEVICE offset()=default
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1398
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:1409
constexpr CK_TILE_HOST_DEVICE offset(const LowLength &low_length, const OffsetLength &offset_length)
Definition: coordinate_transform.hpp:1403
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1417
constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &) const
Definition: coordinate_transform.hpp:1451
Definition: coordinate_transform.hpp:161
constexpr CK_TILE_HOST_DEVICE pad(const LowLength &low_length, const LeftPadLength &left_pad_length, const RightPadLength &right_pad_length)
Definition: coordinate_transform.hpp:173
decltype(make_tuple(LowLength{}+LeftPadLength{}+RightPadLength{})) UpLengths
Definition: coordinate_transform.hpp:165
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:182
LeftPadLength left_pad_length_
Definition: coordinate_transform.hpp:168
UpLengths up_lengths_
Definition: coordinate_transform.hpp:167
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:226
constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:219
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition: coordinate_transform.hpp:195
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:185
constexpr CK_TILE_HOST_DEVICE pad()
Definition: coordinate_transform.hpp:171
RightPadLength right_pad_length_
Definition: coordinate_transform.hpp:169
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:212
Definition: coordinate_transform.hpp:66
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:125
UpLengths up_lengths_
Definition: coordinate_transform.hpp:74
constexpr CK_TILE_HOST_DEVICE pass_through(const LowLength &low_length)
Definition: coordinate_transform.hpp:78
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:118
decltype(make_tuple(LowLength{})) UpLengths
Definition: coordinate_transform.hpp:72
constexpr CK_TILE_HOST_DEVICE pass_through()=default
static constexpr auto type_enum
Definition: coordinate_transform.hpp:67
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition: coordinate_transform.hpp:101
static constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up)
Definition: coordinate_transform.hpp:91
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:88
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:130
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:83
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition: coordinate_transform.hpp:138
Definition: coordinate_transform.hpp:1072
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1100
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1112
constexpr CK_TILE_HOST_DEVICE replicate()=default
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:1107
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &, const UpIdx &) const
Definition: coordinate_transform.hpp:1084
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &, const UpIdxDiff &, LowIdx &, const UpIdx &)
Definition: coordinate_transform.hpp:1092
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1118
constexpr CK_TILE_HOST_DEVICE auto get_upper_lengths() const
Definition: coordinate_transform.hpp:1081
constexpr CK_TILE_HOST_DEVICE replicate(const UpLengths &up_lengths)
Definition: coordinate_transform.hpp:1077
Definition: coordinate_transform.hpp:345
LowLength low_length_
Definition: coordinate_transform.hpp:352
constexpr CK_TILE_HOST_DEVICE right_pad(const LowLength &low_length, const RightPadLength &right_pad_length)
Definition: coordinate_transform.hpp:357
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:365
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:407
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition: coordinate_transform.hpp:378
static constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up)
Definition: coordinate_transform.hpp:368
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition: coordinate_transform.hpp:417
RightPadLength right_pad_length_
Definition: coordinate_transform.hpp:353
constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:402
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:395
constexpr CK_TILE_HOST_DEVICE right_pad()=default
decltype(make_tuple(LowLength{}+RightPadLength{})) UpLengths
Definition: coordinate_transform.hpp:349
UpLengths up_lengths_
Definition: coordinate_transform.hpp:351
Definition: coordinate_transform.hpp:1132
constexpr CK_TILE_HOST_DEVICE slice(const LowLength &, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: coordinate_transform.hpp:1144
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition: coordinate_transform.hpp:1166
constexpr CK_TILE_HOST_DEVICE slice()=default
constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &) const
Definition: coordinate_transform.hpp:1190
SliceBegin slice_begin_
Definition: coordinate_transform.hpp:1139
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1138
decltype(make_tuple(SliceEnd{} - SliceBegin{})) UpLengths
Definition: coordinate_transform.hpp:1136
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:1153
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1183
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1156
SliceEnd slice_end_
Definition: coordinate_transform.hpp:1140
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1195
Definition: functional.hpp:43
Definition: tuple.hpp:192
Definition: coordinate_transform.hpp:828
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition: coordinate_transform.hpp:911
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &) const
Definition: coordinate_transform.hpp:879
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:897
constexpr CK_TILE_HOST_DEVICE unmerge(const UpLengths &up_lengths)
Definition: coordinate_transform.hpp:842
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:856
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:902
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:853
UpLengthsScan up_lengths_scan_
Definition: coordinate_transform.hpp:838
UpLengths up_lengths_
Definition: coordinate_transform.hpp:837
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:890
constexpr CK_TILE_HOST_DEVICE unmerge()=default
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:848
decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number< 1 >{})) UpLengthsScan
Definition: coordinate_transform.hpp:835
Definition: coordinate_transform.hpp:1299
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:1318
constexpr CK_TILE_HOST_DEVICE xor_t(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1311
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1321
constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides) const
Definition: coordinate_transform.hpp:1370
LowLengths UpLengths
Definition: coordinate_transform.hpp:1305
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1307
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:1358
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1363
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:1313
constexpr CK_TILE_HOST_DEVICE xor_t()
Definition: coordinate_transform.hpp:1309
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1334
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1351