/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_description/multi_index_transform.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_description/multi_index_transform.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_description/multi_index_transform.hpp Source File
multi_index_transform.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 
// PassThrough: 1-D transform whose lower index is a copy of the upper index.
// NOTE(review): this listing skips numbered lines (e.g. 12, 14-15, 19, 53, 82, 90), so the
// type header, the up_lengths_ declaration and a few statements are not visible here.
11 template <typename LowLength>
13 {
16 
    // Upper-length tuple type: a one-element tuple deduced from LowLength.
17  using UpLengths = decltype(make_tuple(LowLength{}));
18 
20 
21  __host__ __device__ constexpr PassThrough() = default;
22 
    // Stores the given lower length as the single upper length.
23  __host__ __device__ constexpr PassThrough(const LowLength& low_length)
24  : up_lengths_{make_tuple(low_length)}
25  {
26  }
27 
28  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
29 
30  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
31 
32  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
33 
    // Writes idx_up[0] into idx_low(0); both indices must have exactly one element.
34  template <typename LowIdx, typename UpIdx>
35  __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
36  const UpIdx& idx_up)
37  {
38  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
39  "wrong! inconsistent # of dimension");
40 
41  idx_low(Number<0>{}) = idx_up[Number<0>{}];
42  }
43 
    // Applies an upper-index step: the lower step equals the upper step and is added to idx_low.
    // NOTE(review): listing line 53 (the last parameter) is not shown here.
44  template <typename LowIdxDiff,
45  typename UpIdxDiff,
46  typename LowIdx,
47  typename UpIdx,
48  index_t Hack>
49  __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
50  const UpIdxDiff& idx_diff_up,
51  LowIdx& idx_low,
52  const UpIdx&,
54  {
55  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
56  UpIdx::Size() == 1,
57  "wrong! inconsistent # of dimension");
58 
59  constexpr auto I0 = Number<0>{};
60 
61  idx_diff_low(I0) = idx_diff_up[I0];
62 
63  idx_low += idx_diff_low;
64  }
65 
66  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
67 
    // Always true: this transform reports no upper index as invalid.
68  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
69  {
70  return true;
71  }
72 
73  template <typename UpIdx>
74  __host__ __device__ static constexpr bool
75  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
76  {
77  return true;
78  }
79 
    // NOTE(review): the return statement (listing line 82) is not shown here.
80  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
81  {
83  }
84 
    // Debug print of the transform name and the "up_lengths_" label.
    // NOTE(review): listing line 90 is not shown here.
85  __host__ __device__ void Print() const
86  {
87  printf("{");
88  printf("PassThrough, ");
89  printf("up_lengths_");
91  printf("}");
92  }
93 };
94 
// Pad: 1-D transform with padding on both sides. The upper length is
// low_length + left_pad_length + right_pad_length and the lower index is
// idx_up - left_pad_length_. SkipIsValidCheck short-circuits the validity checks to true.
// NOTE(review): this listing skips numbered lines (e.g. 101-102, 106, 177-179, 187), so the
// up_lengths_ declaration and some statements are not visible here.
95 template <typename LowLength,
96  typename LeftPadLength,
97  typename RightPadLength,
98  bool SkipIsValidCheck = false>
99 struct Pad
100 {
103 
104  using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));
105 
107  LeftPadLength left_pad_length_;
108  RightPadLength right_pad_length_;
109 
110  __host__ __device__ constexpr Pad() = default;
111 
     // Stores both pad lengths and the padded (upper) length.
112  __host__ __device__ constexpr Pad(const LowLength& low_length,
113  const LeftPadLength& left_pad_length,
114  const RightPadLength& right_pad_length)
115  : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
116  left_pad_length_{left_pad_length},
117  right_pad_length_{right_pad_length}
118  {
119  }
120 
121  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
122 
123  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
124 
125  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
126 
     // Lower index = upper index shifted down by the left pad; the result is negative for
     // upper indices that fall inside the left padding.
127  template <typename LowIdx, typename UpIdx>
128  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
129  const UpIdx& idx_up) const
130  {
131  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
132  "wrong! inconsistent # of dimension");
133 
134  idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
135  }
136 
     // Applies an upper-index step: the lower step equals the upper step (the pad offset is
     // constant), and idx_low is advanced by it.
137  template <typename LowIdxDiff,
138  typename UpIdxDiff,
139  typename LowIdx,
140  typename UpIdx,
141  index_t Hack>
142  __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
143  const UpIdxDiff& idx_diff_up,
144  LowIdx& idx_low,
145  const UpIdx&,
146  Number<Hack>)
147  {
148  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
149  UpIdx::Size() == 1,
150  "wrong! inconsistent # of dimension");
151 
152  constexpr auto I0 = Number<0>{};
153 
154  idx_diff_low(I0) = idx_diff_up[I0];
155 
156  idx_low += idx_diff_low;
157  }
158 
159  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
160 
     // Compile-time answer: true only when the caller opted out of validity checking.
161  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
162  {
163  return SkipIsValidCheck;
164  }
165 
     // True when SkipIsValidCheck is set, or when idx_up lies in
     // [left_pad_length_, up_lengths_[0] - right_pad_length_), i.e. outside both paddings.
166  template <typename UpIdx>
167  __host__ __device__ constexpr bool
168  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
169  {
170  return SkipIsValidCheck ||
171  ((idx_up[Number<0>{}] >= left_pad_length_) &&
172  (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_length_));
173  }
174 
     // NOTE(review): the return statement (listing lines 177-179) is not shown here.
175  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
176  {
180  }
181 
     // Debug print of the transform name and both pad lengths (converted to index_t).
     // NOTE(review): listing line 187 is not shown here.
182  __host__ __device__ void Print() const
183  {
184  printf("{");
185  printf("Pad, ");
186  printf("up_lengths_");
188  printf("left_pad_length %d", index_t{left_pad_length_});
189  printf("right_pad_length %d", index_t{right_pad_length_});
190  printf("}");
191  }
192 };
193 
// LeftPad: 1-D transform with padding on the left only. The upper length is
// low_length + left_pad_length and the lower index is idx_up - left_pad_length_.
// SkipIsValidCheck short-circuits the validity checks to true.
// NOTE(review): this listing skips numbered lines (e.g. 197-198, 202, 267-268, 276), so the
// up_lengths_ declaration and some statements are not visible here.
194 template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
195 struct LeftPad
196 {
199 
200  using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));
201 
203  LeftPadLength left_pad_length_;
204 
205  __host__ __device__ constexpr LeftPad() = default;
206 
     // Stores the pad length and the padded (upper) length.
207  __host__ __device__ constexpr LeftPad(const LowLength& low_length,
208  const LeftPadLength& left_pad_length)
209  : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
210  {
211  }
212 
213  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
214 
215  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
216 
217  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
218 
     // Lower index = upper index shifted down by the left pad.
219  template <typename LowIdx, typename UpIdx>
220  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
221  const UpIdx& idx_up) const
222  {
223  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
224  "wrong! inconsistent # of dimension");
225 
226  idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
227  }
228 
     // Applies an upper-index step: the lower step equals the upper step, and idx_low is
     // advanced by it.
229  template <typename LowIdxDiff,
230  typename UpIdxDiff,
231  typename LowIdx,
232  typename UpIdx,
233  index_t Hack>
234  __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
235  const UpIdxDiff& idx_diff_up,
236  LowIdx& idx_low,
237  const UpIdx&,
238  Number<Hack>)
239  {
240  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
241  UpIdx::Size() == 1,
242  "wrong! inconsistent # of dimension");
243 
244  constexpr auto I0 = Number<0>{};
245 
246  idx_diff_low(I0) = idx_diff_up[I0];
247 
248  idx_low += idx_diff_low;
249  }
250 
251  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
252 
     // Compile-time answer: true only when the caller opted out of validity checking.
253  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
254  {
255  return SkipIsValidCheck;
256  }
257 
     // True when SkipIsValidCheck is set, or when idx_up is at or beyond the left padding.
258  template <typename UpIdx>
259  __host__ __device__ constexpr bool
260  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
261  {
262  return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_length_);
263  }
264 
     // NOTE(review): the return statement (listing lines 267-268) is not shown here.
265  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
266  {
269  }
270 
     // Debug print of the transform name and the left pad length (converted to index_t).
     // NOTE(review): listing line 276 is not shown here.
271  __host__ __device__ void Print() const
272  {
273  printf("{");
274  printf("LeftPad, ");
275  printf("up_lengths_");
277  printf("left_pad_length_ %d", index_t{left_pad_length_});
278  printf("}");
279  }
280 };
281 
282 template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
283 struct RightPad
284 {
287 
288  using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));
289 
291  LowLength low_length_;
292  RightPadLength right_pad_length_;
293 
294  __host__ __device__ constexpr RightPad() = default;
295 
296  __host__ __device__ constexpr RightPad(const LowLength& low_length,
297  const RightPadLength& right_pad_length)
298  : up_lengths_{make_tuple(low_length + right_pad_length)},
299  low_length_{low_length},
300  right_pad_length_{right_pad_length}
301  {
302  }
303 
304  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
305 
306  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
307 
308  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
309 
310  template <typename LowIdx, typename UpIdx>
311  __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
312  const UpIdx& idx_up)
313  {
314  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
315  "wrong! inconsistent # of dimension");
316 
317  idx_low(Number<0>{}) = idx_up[Number<0>{}];
318  }
319 
320  template <typename LowIdxDiff,
321  typename UpIdxDiff,
322  typename LowIdx,
323  typename UpIdx,
324  index_t Hack>
325  __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
326  const UpIdxDiff& idx_diff_up,
327  LowIdx& idx_low,
328  const UpIdx&,
329  Number<Hack>)
330  {
331  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
332  UpIdx::Size() == 1,
333  "wrong! inconsistent # of dimension");
334 
335  constexpr auto I0 = Number<0>{};
336 
337  idx_diff_low(I0) = idx_diff_up[I0];
338 
339  idx_low += idx_diff_low;
340  }
341 
342  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
343 
344  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
345  {
346  return SkipIsValidCheck;
347  }
348 
349  template <typename UpIdx>
350  __host__ __device__ constexpr bool
351  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
352  {
353  return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_);
354  }
355 
356  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
357  {
361  }
362 
363  __host__ __device__ void Print() const
364  {
365  printf("{");
366  printf("RightPad, ");
367  printf("up_lengths_");
369  printf("low_length_ %d", index_t{low_length_});
370  printf("left_pad_length_ %d", index_t{right_pad_length_});
371  printf("}");
372  }
373 };
374 
375 // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
376 // UpLengths and Coefficients can each be any of the following:
377 // 1) Tuple of index_t, which is known at run-time, or
378 // 2) Tuple of Number, which is known at compile-time, or
379 // 3) Tuple of mixture of index_t and Number, which is known partially at run-time and partially
380 // at compile-time
// Embed: maps an NDimUp-dimensional upper index to a single lower index computed as
// sum over i of idx_up[i] * coefficients_[i]. The enable_if parameter requires UpLengths and
// Coefficients to have the same number of elements.
// NOTE(review): this listing skips numbered lines (e.g. 388-389, 439, 461-462, 470, 472), so
// some declarations and statements are not visible here.
381 template <typename UpLengths,
382  typename Coefficients,
383  typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
384 struct Embed
385 {
386  static constexpr index_t NDimUp = UpLengths::Size();
387 
390 
391  UpLengths up_lengths_;
392  Coefficients coefficients_;
393 
394  __host__ __device__ constexpr Embed() = default;
395 
     // Stores the upper lengths and the per-dimension coefficients as given.
396  __host__ __device__ constexpr Embed(const UpLengths& up_lengths,
397  const Coefficients& coefficients)
398  : up_lengths_{up_lengths}, coefficients_{coefficients}
399  {
400  }
401 
402  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
403 
404  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
405 
406  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
407 
     // Zeroes idx_low(0), then accumulates idx_up[i] * coefficients_[i] for each i visited by
     // static_for<0, NDimUp, 1>.
408  template <typename LowIdx, typename UpIdx>
409  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
410  const UpIdx& idx_up) const
411  {
412  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
413  "wrong! inconsistent # of dimension");
414 
415  idx_low(Number<0>{}) = 0;
416 
417  static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
418  idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i];
419  });
420  }
421 
     // Applies an upper-index step: the lower step is accumulated from
     // idx_diff_up[i] * coefficients_[i], then added to idx_low.
     // NOTE(review): listing line 439 (the call that receives the lambda below) is not shown;
     // presumably it is the same static_for<0, NDimUp, 1> loop as above -- confirm.
422  template <typename LowIdxDiff,
423  typename UpIdxDiff,
424  typename LowIdx,
425  typename UpIdx,
426  index_t Hack>
427  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
428  const UpIdxDiff& idx_diff_up,
429  LowIdx& idx_low,
430  const UpIdx&,
431  Number<Hack>) const
432  {
433  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&
434  LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
435  "wrong! inconsistent # of dimension");
436 
437  idx_diff_low(Number<0>{}) = 0;
438 
440  [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
441 
442  idx_low += idx_diff_low;
443  }
444 
445  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
446 
     // Always true: this transform reports no upper index as invalid.
447  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
448  {
449  return true;
450  }
451 
452  template <typename UpIdx>
453  __host__ __device__ static constexpr bool
454  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
455  {
456  return true;
457  }
458 
     // NOTE(review): the return statement (listing lines 461-462) is not shown here.
459  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
460  {
463  }
464 
     // Debug print of the transform name and the member labels.
     // NOTE(review): listing lines 470 and 472 are not shown here.
465  __host__ __device__ void Print() const
466  {
467  printf("{");
468  printf("Embed, ");
469  printf("up_lengths_ ");
471  printf("coefficients_ ");
473  printf("}");
474  }
475 };
476 
477 // Implementation of "Merge" transformation primitive that uses regular division to do lowering of
478 // the multi-index and uses a carry-and-borrow check to do lowering of the multi-index delta
479 template <typename LowLengths>
481 {
482  static constexpr index_t NDimLow = LowLengths::Size();
483 
486 
488  decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
489 
490  using UpLengths =
491  decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
492 
493  LowLengths low_lengths_;
496 
497  __host__ __device__ constexpr Merge_v1_carry_check() = default;
498 
// Builds the transform from the lower lengths. up_lengths_ holds the container_reduce of
// low_lengths with math::multiplies starting from Number<1>.
// NOTE(review): listing line 501 is not shown; line 502 is the tail of an initializer whose
// target is not visible here (presumably low_lengths_scan_ -- confirm).
499  __host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths)
500  : low_lengths_{low_lengths},
502  container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
503  up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
504  {
     // The lower index type must have one element per lower dimension.
505  static_assert(LowerIndex::Size() == NDimLow, "wrong!");
506  }
507 
// Number of lower dimensions: one per element of LowLengths (NDimLow).
508  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
509 
// Number of upper dimensions: always 1 for this transform.
510  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
511 
// Returns a const reference to the stored upper-length tuple.
512  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
513 
// Splits the single upper index into NDimLow lower indices by repeated integer division:
// each of the first NDimLow - 1 components is the quotient by low_lengths_scan_[i], the
// remainder is carried to the next component, and the last component takes what is left.
514  template <typename LowIdx, typename UpIdx>
515  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
516  const UpIdx& idx_up) const
517  {
518  static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
519  "wrong! inconsistent # of dimension");
520 
521  index_t tmp = idx_up[Number<0>{}];
522 
523  // normal division
524  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
525  idx_low(i) = tmp / this->low_lengths_scan_[i];
     // subtract the part accounted for by component i, i.e. keep the remainder
526  tmp -= idx_low[i] * this->low_lengths_scan_[i];
527  });
528 
529  idx_low(Number<NDimLow - 1>{}) = tmp;
530  }
531 
// Applies an upper-index step to the lower multi-index using carry/borrow checks instead of
// re-dividing the new upper index. The step is first decomposed into a per-dimension delta
// (idx_diff_low_const) by division; then, per dimension, the delta is adjusted by the length
// when a carry or borrow is detected. Hack selects the variant at compile time:
// Hack == 1 checks carries only, Hack == 2 checks borrows only, any other value checks both.
// idx_diff_low receives the per-dimension step actually applied; idx_low is updated in place.
532  template <typename LowIdxDiff,
533  typename UpIdxDiff,
534  typename LowIdx,
535  typename UpIdx,
536  index_t Hack>
537  __host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff& idx_diff_low,
538  const UpIdxDiff& idx_diff_up,
539  LowIdx& idx_low,
540  const UpIdx& /* idx_up_new */,
541  Number<Hack>) const
542  {
543  static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
544  LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
545  "wrong! inconsistent # of dimension");
546 
547  // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
548  // However,
549  // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const
550  // can be calculated at compile-time.
551  // 2) If idx_diff_up is not known at compile-time, but its value
552  // doesn't change during the whole kernel execution, then
553  // idx_diff_low_const also
554  // doesn't change during the whole kernel execution. Compiler generated
555  // ISA should
556  // only calculate idx_diff_low_const once and save it during the whole
557  // kernel execution
558  // If neither 1) nor 2) is satisfied, then the calculation will also be
559  // computed at
560  // run-time each time this function is called, and can be very expensive.
561  LowerIndex idx_diff_low_const;
562  LowerIndex idx_low_length_minus_idx_diff_low_const;
563  LowerIndex idx_low_length_plus_idx_diff_low_const;
564 
565 #if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
566  index_t tmp = idx_diff_up[Number<0>{}];
567 
568  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
569  idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
570  tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
571  });
572 
573  idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
574 
     // precompute (length - delta) and (length + delta) for the carry / borrow tests below
575  static_for<0, NDimLow, 1>{}([&](auto i) {
576  idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];
577 
578  idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
579  });
580 #else
581  // Hack: this force result into SGPR. Need to make sure the result is thread invariant
582  index_t tmp = idx_diff_up[Number<0>{}];
583 
584  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
585  idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
586  tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
587  });
588 
589  idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
590 
591  static_for<0, NDimLow, 1>{}([&](auto i) {
592  idx_low_length_minus_idx_diff_low_const(i) =
593  __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);
594 
595  idx_low_length_plus_idx_diff_low_const(i) =
596  __builtin_amdgcn_readfirstlane(low_lengths_[i] + idx_diff_low_const[i]);
597  });
598 #endif
599 
600  if constexpr(Hack == 1)
601  {
602  // do carry check on each low dimension in reversed order
603  // do not need to check the first dimension
604  index_t carry = 0;
605 
606  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
607  index_t idx_low_tmp = idx_low[i] + carry;
608 
     // carry when idx_low_tmp + delta would reach or exceed the length of dimension i
609  bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
610 
611  idx_diff_low(i) =
612  do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
613 
614  idx_diff_low(i) += carry;
615 
616  carry = do_carry ? 1 : 0;
617  });
618 
     // dimension 0 only absorbs the final carry; it is never wrapped
619  idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
620 
621  idx_low += idx_diff_low;
622  }
623  else if constexpr(Hack == 2)
624  {
625  // do borrow check on each low dimension in reversed order
626  // do not need to check the first dimension
627  index_t borrow = 0;
628 
629  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
630  index_t idx_low_tmp = idx_low[i] - borrow;
631 
     // borrow when idx_low_tmp + delta would fall below zero
632  bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
633 
634  idx_diff_low(i) =
635  do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];
636 
637  idx_diff_low(i) -= borrow;
638 
639  borrow = do_borrow ? 1 : 0;
640  });
641 
642  idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;
643 
644  idx_low += idx_diff_low;
645  }
646  else
647  {
648  // do carry and borrow check on each low dimension in reversed order
649  // do not need to check the first dimension
650  index_t carry = 0;
651 
652  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
653  index_t idx_low_tmp = idx_low[i] + carry;
654 
655  bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
656  bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
657 
658  idx_diff_low(i) =
659  do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
     // the borrow choice, when taken, overrides the value picked for the carry case
660  idx_diff_low(i) =
661  do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];
662 
663  idx_diff_low(i) += carry;
664 
     // carry becomes +1 on carry, -1 on borrow, 0 otherwise
665  carry = do_carry ? 1 : 0;
666  carry = do_borrow ? -1 : carry;
667  });
668 
669  idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
670 
671  idx_low += idx_diff_low;
672  }
673  }
674 
// Variant of the carry/borrow update. It differs from UpdateLowerIndex_1a in the Hack == 2
// branch, where the borrow test is written on the negated index
// (borrow - idx_low[i] > idx_diff_low_const[i]), and in the readfirstlane branch noted below.
// Hack == 1 checks carries only, Hack == 2 checks borrows only, any other value checks both.
// idx_diff_low receives the per-dimension step actually applied; idx_low is updated in place.
675  template <typename LowIdxDiff,
676  typename UpIdxDiff,
677  typename LowIdx,
678  typename UpIdx,
679  index_t Hack>
680  __host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff& idx_diff_low,
681  const UpIdxDiff& idx_diff_up,
682  LowIdx& idx_low,
683  const UpIdx& /* idx_up_new */,
684  Number<Hack>) const
685  {
686  static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
687  LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
688  "wrong! inconsistent # of dimension");
689 
690  // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
691  // However,
692  // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const
693  // can be calculated at compile-time.
694  // 2) If idx_diff_up is not known at compile-time, but its value
695  // doesn't change during the whole kernel execution, then
696  // idx_diff_low_const also
697  // doesn't change during the whole kernel execution. Compiler generated
698  // ISA should
699  // only calculate idx_diff_low_const once and save it during the whole
700  // kernel execution
701  // If neither 1) nor 2) is satisfied, then the calculation will also be
702  // computed at
703  // run-time each time this function is called, and can be very expensive.
704  LowerIndex idx_diff_low_const;
705  LowerIndex idx_low_length_minus_idx_diff_low_const;
706  LowerIndex idx_low_length_plus_idx_diff_low_const;
707 
708 #if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
709  index_t tmp = idx_diff_up[Number<0>{}];
710 
711  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
712  idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
713  tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
714  });
715 
716  idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
717 
     // precompute (length - delta) and (length + delta) for the carry / borrow tests below
718  static_for<0, NDimLow, 1>{}([&](auto i) {
719  idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];
720 
721  idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
722  });
723 #else
724  // Hack: this force result into SGPR. Need to make sure the result is thread invariant
725  index_t tmp = idx_diff_up[Number<0>{}];
726 
727  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
728  idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
729  tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
730  });
731 
732  idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
733 
734  static_for<0, NDimLow, 1>{}([&](auto i) {
735  idx_low_length_minus_idx_diff_low_const(i) =
736  __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);
737 
     // NOTE(review): unlike UpdateLowerIndex_1a, this value is not passed through
     // __builtin_amdgcn_readfirstlane -- confirm whether that is intentional.
738  idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
739  });
740 #endif
741 
742  if constexpr(Hack == 1)
743  {
744  // do carry check on each low dimension in reversed order
745  // do not need to check the first dimension
746  index_t carry = 0;
747 
748  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
749  index_t idx_low_tmp = idx_low[i] + carry;
750 
     // carry when idx_low_tmp + delta would reach or exceed the length of dimension i
751  bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
752 
753  idx_diff_low(i) =
754  do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
755 
756  idx_diff_low(i) += carry;
757 
758  carry = do_carry ? 1 : 0;
759  });
760 
     // dimension 0 only absorbs the final carry; it is never wrapped
761  idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
762 
763  idx_low += idx_diff_low;
764  }
765  else if constexpr(Hack == 2)
766  {
767  // do borrow check on each low dimension in reversed order
768  // do not need to check the first dimension
769  index_t borrow = 0;
770 
771  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
772  index_t negative_idx_low_tmp = borrow - idx_low[i];
773 
     // same test as (idx_low[i] - borrow) < -delta, written on the negated values
774  bool do_borrow = negative_idx_low_tmp > idx_diff_low_const[i];
775 
776  idx_diff_low(i) =
777  do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];
778 
779  idx_diff_low(i) -= borrow;
780 
781  borrow = do_borrow ? 1 : 0;
782  });
783 
784  idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;
785 
786  idx_low += idx_diff_low;
787  }
788  else
789  {
790  // do carry and borrow check on each low dimension in reversed order
791  // do not need to check the first dimension
792  index_t carry = 0;
793 
794  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
795  index_t idx_low_tmp = idx_low[i] + carry;
796 
797  bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
798  bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
799 
800  idx_diff_low(i) =
801  do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
     // the borrow choice, when taken, overrides the value picked for the carry case
802  idx_diff_low(i) =
803  do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];
804 
805  idx_diff_low(i) += carry;
806 
     // carry becomes +1 on carry, -1 on borrow, 0 otherwise
807  carry = do_carry ? 1 : 0;
808  carry = do_borrow ? -1 : carry;
809  });
810 
811  idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
812 
813  idx_low += idx_diff_low;
814  }
815  }
816 
// Variant that updates idx_low one dimension at a time inside the loop: each dimension adds
// its delta plus the incoming carry (or minus the incoming borrow), then wraps by the
// dimension length when the tentative index leaves [0, length).
// Hack == 1 handles carries, Hack == 2 handles borrows; for any other Hack value the body is
// empty, so idx_low and idx_diff_low are left untouched.
817  template <typename LowIdxDiff,
818  typename UpIdxDiff,
819  typename LowIdx,
820  typename UpIdx,
821  index_t Hack>
822  __host__ __device__ void UpdateLowerIndex_2(LowIdxDiff& idx_diff_low,
823  const UpIdxDiff& idx_diff_up,
824  LowIdx& idx_low,
825  const UpIdx& /* idx_up_new */,
826  Number<Hack>) const
827  {
828  static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
829  LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
830  "wrong! inconsistent # of dimension");
831 
832  // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
833  // However,
834  // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const
835  // can be calculated at compile-time.
836  // 2) If idx_diff_up is not known at compile-time, but its value
837  // doesn't change during the whole kernel execution, then
838  // idx_diff_low_const also
839  // doesn't change during the whole kernel execution. Compiler generated
840  // ISA should
841  // only calculate idx_diff_low_const once and save it during the whole
842  // kernel execution
843  // If neither 1) nor 2) is satisfied, then the calculation will also be
844  // computed at run-time each time this function is called, and can be
845  // very expensive.
846  LowerIndex idx_diff_low_const;
847 
848 #if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
849  index_t tmp = idx_diff_up[Number<0>{}];
850 
851  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
852  idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
853  tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
854  });
855 
856  idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
857 #else
858  // Hack: this force result into SGPR. Need to make sure the result is thread invariant
859  index_t tmp = idx_diff_up[Number<0>{}];
860 
861  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
862  idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
863  tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
864  });
865 
866  idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
867 #endif
868 
869  if constexpr(Hack == 1)
870  {
871  // do carry check on each low dimension in reversed order
872  // do not need to check the first dimension
873  bool do_carry = 0;
874 
875  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
     // do_carry from the previous (inner) dimension is added as 0 or 1
876  idx_diff_low(i) = idx_diff_low_const[i] + do_carry;
877 
878  index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];
879 
880  do_carry = idx_low_tmp >= low_lengths_[i];
881 
     // only the "#elif 1" form directly below is compiled; the others are kept as alternatives
882 #if 0
883  // TODO: use exec-mask inline asm, which use 1 VALU
884  if(do_carry)
885  {
886  idx_diff_low(i) -= low_lengths_[i];
887  }
888 #elif 1
889  // this use 2 VALU
890  idx_diff_low(i) = do_carry ? idx_diff_low[i] - low_lengths_[i] : idx_diff_low[i];
891 #elif 1
892  // this use 2 VALU
893  index_t idx_diff_low_tmp = idx_diff_low[i] - low_lengths_[i];
894  idx_diff_low(i) = do_carry ? idx_diff_low_tmp : idx_diff_low[i];
895 #endif
896 
897  idx_low(i) += idx_diff_low[i];
898  });
899 
900  constexpr auto I0 = Number<0>{};
901 
     // dimension 0 only absorbs the final carry; it is never wrapped
902  idx_diff_low(I0) = idx_diff_low_const[I0] + do_carry;
903 
904  idx_low(I0) += idx_diff_low[I0];
905  }
906  else if constexpr(Hack == 2)
907  {
908  // do borrow check on each low dimension in reversed order
909  // do not need to check the first dimension
910  bool do_borrow = 0;
911 
912  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
     // do_borrow from the previous (inner) dimension is subtracted as 0 or 1
913  idx_diff_low(i) = idx_diff_low_const[i] - do_borrow;
914 
915  index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];
916 
917  do_borrow = idx_low_tmp < 0;
918 
     // only the "#elif 1" form directly below is compiled; the others are kept as alternatives
919 #if 0
920  // TODO: use exec-mask inline asm
921  if(do_borrow)
922  {
923  idx_diff_low(i) += low_lengths_[i];
924  }
925 #elif 1
926  idx_diff_low(i) = do_borrow ? idx_diff_low[i] + low_lengths_[i] : idx_diff_low[i];
927 #elif 1
928  index_t idx_diff_low_tmp = idx_diff_low[i] + low_lengths_[i];
929  idx_diff_low(i) = do_borrow ? idx_diff_low_tmp : idx_diff_low[i];
930 #endif
931 
932  idx_low(i) += idx_diff_low[i];
933  });
934 
935  constexpr auto I0 = Number<0>{};
936 
937  idx_diff_low(I0) = idx_diff_low_const[I0] - do_borrow;
938 
939  idx_low(I0) += idx_diff_low[I0];
940  }
941  else
942  {
943  // not implemented
944  }
945  }
946 
// Entry point for applying an upper-index step. The preprocessor chain selects one of the
// three implementations at compile time; as written ("#if 1") UpdateLowerIndex_1a is used and
// the _1b and _2 calls are not compiled.
947  template <typename LowIdxDiff,
948  typename UpIdxDiff,
949  typename LowIdx,
950  typename UpIdx,
951  index_t Hack>
952  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
953  const UpIdxDiff& idx_diff_up,
954  LowIdx& idx_low,
955  const UpIdx& idx_up_new,
956  Number<Hack>) const
957  {
958 #if 1
959  UpdateLowerIndex_1a(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
960 #elif 0
961  UpdateLowerIndex_1b(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
962 #else
963  UpdateLowerIndex_2(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
964 #endif
965  }
966 
// Reports false: this transform is not flagged as linear (its lowering uses integer division).
967  __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
968 
// Always true: this transform reports no upper index as invalid.
969  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
970  {
971  return true;
972  }
973 
// NOTE(review): the return statement (listing lines 976-978) is not shown here, so the
// condition this predicate evaluates cannot be read from this listing.
974  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
975  {
979  }
980 
// Per-index validity check; the argument is ignored and the result is always true.
981  template <typename UpIdx>
982  __host__ __device__ static constexpr bool
983  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
984  {
985  return true;
986  }
987 
// Debug print: writes the transform name followed by low_lengths_, low_lengths_scan_ and
// up_lengths_ (each via print_multi_index), wrapped in braces.
988  __host__ __device__ void Print() const
989  {
990  printf("{");
991  printf("Merge_v1_carry_check, ");
992  printf("low_lengths_ ");
993  print_multi_index(low_lengths_);
994  printf("low_lengths_scan_ ");
995  print_multi_index(low_lengths_scan_);
996  printf("up_lengths_ ");
997  print_multi_index(up_lengths_);
998  printf("}");
999  }
1000 };
1001 
1002 template <typename LowLengths>
1004 {
// Returns MagicDivision::CalculateMagicMultiplier for element i of a default-constructed
// LowLengths.
// NOTE(review): the enclosing functor's header (listing line 1003) is not shown here.
1005  template <index_t I>
1006  __host__ __device__ constexpr auto operator()(Number<I> i) const
1007  {
1008  return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]);
1009  }
1010 };
1011 
1012 template <typename LowLengths>
1014 {
// Returns MagicDivision::CalculateMagicShift for element i of a default-constructed
// LowLengths.
// NOTE(review): the enclosing functor's header (listing line 1013) is not shown here.
1015  template <index_t I>
1016  __host__ __device__ constexpr auto operator()(Number<I> i) const
1017  {
1018  return MagicDivision::CalculateMagicShift(LowLengths{}[i]);
1019  }
1020 };
1021 
1022 // Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering
1023 // of both multi-index and delta of multi-index
1024 // Caution:
1025 // 1. The magic number division implementation being used produces a correct result if the
1026 // dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
1027 // 2. The magic number division for an int32_t dividend has not been implemented; the int32_t
1028 // dividend is bit-wise interpreted as uint32_t and the magic number division implementation for
1029 // uint32_t is then used.
1030 // 3. For Merge primitive, upper-index is the dividend.
1031 // 4. When upper-index is uint32_t, its value needs to be within 31-bit range.
1032 // 5. When upper-index is int32_t type (when index_t is int32_t), its value needs to be
1033 // non-negative.
1034 template <typename LowLengths>
1036 {
1037  static constexpr index_t NDimLow = LowLengths::Size();
1038 
1041 
1042  using UpLengths =
1043  decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
1044 
1047  Number<NDimLow>{}));
1048 
1051  Number<NDimLow>{}));
1052 
1053  LowLengths low_lengths_;
1057 
1058  __host__ __device__ constexpr Merge_v2_magic_division() = default;
1059 
// Constructs the transform from the lower lengths.
// For every lower dimension i a magic-division multiplier and shift are generated from
// low_lengths[i]; the single upper length is low_lengths reduced with math::multiplies
// starting from Number<1>.
1060  __host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths)
1061  : low_lengths_{low_lengths},
1062  low_lengths_magic_divisor_multiplier_{generate_tuple(
1063  [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
1064  Number<NDimLow>{})},
1065  low_lengths_magic_divisor_shift_{generate_tuple(
1066  [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); },
1067  Number<NDimLow>{})},
1068  up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
1069  {
// Compile-time sanity check that the lower index type has one entry per lower length.
1070  static_assert(LowerIndex::Size() == NDimLow, "wrong!");
1071  }
1072 
1073  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
1074 
1075  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1076 
1077  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1078 
// Decomposes the single merged upper index into NDimLow lower indices.
// idx_low : output, one entry per lower dimension.
// idx_up  : input, exactly one entry (the merged index).
1079  template <typename LowIdx, typename UpIdx>
1080  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1081  const UpIdx& idx_up) const
1082  {
1083  static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
1084  "wrong! inconsistent # of dimension");
1085 
// Running value that still has to be decomposed; starts as the merged upper index.
1086  index_t tmp = idx_up[Number<0>{}];
1087 
// Walk the dimensions starting at NDimLow - 1 with step -1; dimension 0 is assigned after the loop.
1088  static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
// tmp2: magic-division result for dimension i. The multiplier/shift pair was generated from
// low_lengths[i] in the constructor, so this is presumably tmp / low_lengths_[i].
1089  index_t tmp2 =
1090  MagicDivision::DoMagicDivision(tmp,
1091  this->low_lengths_magic_divisor_multiplier_[i],
1092  this->low_lengths_magic_divisor_shift_[i]);
// What is left of tmp after removing tmp2 whole multiples of low_lengths_[i].
1093  idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
// Carry the division result on to the next dimension.
1094  tmp = tmp2;
1095  });
1096 
// Whatever remains after the loop becomes the index of dimension 0.
1097  idx_low(Number<0>{}) = tmp;
1098  }
1099 
// Recomputes the lower index from the new upper index and reports, per lower dimension,
// how much it changed.
// idx_diff_low : output, new lower index minus previous lower index.
// (unnamed)    : the upper-index delta is accepted but not used.
// idx_low      : in/out, holds the previous lower index on entry and the new one on return.
// idx_up_new   : input, the upper index to decompose.
1100  template <typename LowIdxDiff,
1101  typename UpIdxDiff,
1102  typename LowIdx,
1103  typename UpIdx,
1104  index_t Hack>
1105  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1106  const UpIdxDiff&,
1107  LowIdx& idx_low,
1108  const UpIdx& idx_up_new,
1109  Number<Hack>) const
1110  {
1111  static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
1112  LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
1113  "wrong! inconsistent # of dimension");
1114 
// Running value that still has to be decomposed.
1115  index_t tmp = idx_up_new[Number<0>{}];
1116 
// Same decomposition as CalculateLowerIndex, with the old value captured before each overwrite.
1117  static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
1118  index_t tmp2 =
1119  MagicDivision::DoMagicDivision(tmp,
1120  this->low_lengths_magic_divisor_multiplier_[i],
1121  this->low_lengths_magic_divisor_shift_[i]);
1122 
// Remember the previous index of this dimension so the delta can be formed below.
1123  index_t idx_low_old = idx_low[i];
1124 
1125  idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
1126  tmp = tmp2;
1127 
1128  idx_diff_low(i) = idx_low[i] - idx_low_old;
1129  });
1130 
// Order matters: the delta for dimension 0 is taken while idx_low still holds the old value.
1131  idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{});
1132 
1133  idx_low(Number<0>{}) = tmp;
1134  }
1135 
1136  __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
1137 
1138  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1139  {
1140  return true;
1141  }
1142 
1143  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1144  {
1149  }
1150 
1151  template <typename UpIdx>
1152  __host__ __device__ static constexpr bool
1153  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
1154  {
1155  return true;
1156  }
1157 
1158  __host__ __device__ void Print() const
1159  {
1160  printf("{");
1161  printf("Merge_v2_magic_division, ");
1162  printf("low_lengths_ ");
1163  print_multi_index(low_lengths_);
1164  printf("low_lengths_magic_divisor_multiplier_ ");
1165  print_multi_index(low_lengths_magic_divisor_multiplier_);
1166  printf("low_lengths_magic_divisor_shift_ ");
1167  print_multi_index(low_lengths_magic_divisor_shift_);
1168  printf("up_lengths_ ");
1169  print_multi_index(up_lengths_);
1170  printf("}");
1171  }
1172 };
1173 
1174 // Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering
1175 // of both multi-index and delta of multi-index
1176 // Caution:
1177 // 1. The magic number division implementation being used produces a correct result if the
1178 // dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
1179 // 2. The magic number division for an int32_t dividend has not been implemented; the int32_t
1180 // dividend is bit-wise interpreted as uint32_t and the magic number division implementation for
1181 // uint32_t is then used.
1182 // 3. For Merge primitive, upper-index is the dividend.
1183 // 4. When upper-index is uint32_t, its value needs to be within 31-bit range.
1184 // 5. When upper-index is int32_t type (when index_t is int32_t), its value needs to be
1185 // non-negative.
1186 template <typename LowLengths>
1188 {
1189  static constexpr index_t NDimLow = LowLengths::Size();
1190 
1193 
1195  decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
1196 
1197  using UpLengths =
1198  decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
1199 
1202  Number<NDimLow>{}));
1203 
1206  Number<NDimLow>{}));
1207 
1208  LowLengths low_lengths_;
1213 
1214  __host__ __device__ constexpr Merge_v2r2_magic_division() = default;
1215 
// Constructs the transform from the lower lengths.
// low_lengths_scan_ is the reverse exclusive scan of low_lengths with math::multiplies (seed
// Number<1>); a magic-division multiplier and shift are then generated per scan entry. The
// single upper length is low_lengths reduced with math::multiplies starting from Number<1>.
1216  __host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths)
1217  : low_lengths_{low_lengths},
1218  low_lengths_scan_{
1219  container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
// NOTE(review): the two lambdas below read the member low_lengths_scan_, so they rely on it
// being initialized first, i.e. declared before the magic-divisor members. The member
// declarations are not visible in this listing - confirm.
1220  low_lengths_scan_magic_divisor_multiplier_{generate_tuple(
1221  [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); },
1222  Number<NDimLow>{})},
1223  low_lengths_scan_magic_divisor_shift_{generate_tuple(
1224  [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); },
1225  Number<NDimLow>{})},
1226  up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
1227  {
// Compile-time sanity check that the lower index type has one entry per lower length.
1228  static_assert(LowerIndex::Size() == NDimLow, "wrong!");
1229  }
1230 
1231  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
1232 
1233  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1234 
1235  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1236 
// Decomposes the single merged upper index into NDimLow lower indices, using the scanned
// lengths (low_lengths_scan_) as divisors.
// idx_low : output, one entry per lower dimension.
// idx_up  : input, exactly one entry (the merged index).
1237  template <typename LowIdx, typename UpIdx>
1238  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1239  const UpIdx& idx_up) const
1240  {
1241  static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
1242  "wrong! inconsistent # of dimension");
1243 
// Running remainder; starts as the merged upper index.
1244  index_t tmp = idx_up[Number<0>{}];
1245 
// Walk the dimensions from 0 upward; dimension NDimLow - 1 is assigned after the loop.
1246  static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
// Magic division by the scan entry of dimension i (multiplier/shift were generated from
// low_lengths_scan_[i] in the constructor), presumably tmp / low_lengths_scan_[i].
1247  idx_low(i) =
1248  MagicDivision::DoMagicDivision(tmp,
1249  this->low_lengths_scan_magic_divisor_multiplier_[i],
1250  this->low_lengths_scan_magic_divisor_shift_[i]);
1251 
// Remove the part of tmp accounted for by this dimension.
1252  tmp -= idx_low[i] * this->low_lengths_scan_[i];
1253  });
1254 
// What is left over is the index of the last dimension.
1255  idx_low(Number<NDimLow - 1>{}) = tmp;
1256  }
1257 
// Recomputes the lower index from the new upper index and reports, per lower dimension,
// how much it changed.
// idx_diff_low : output, new lower index minus previous lower index.
// (unnamed)    : the upper-index delta is accepted but not used.
// idx_low      : in/out, holds the previous lower index on entry and the new one on return.
// idx_up_new   : input, the upper index to decompose.
1258  template <typename LowIdxDiff,
1259  typename UpIdxDiff,
1260  typename LowIdx,
1261  typename UpIdx,
1262  index_t Hack>
1263  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1264  const UpIdxDiff&,
1265  LowIdx& idx_low,
1266  const UpIdx& idx_up_new,
1267  Number<Hack>) const
1268  {
1269  static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
1270  LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
1271  "wrong! inconsistent # of dimension");
1272 
// Running remainder; starts as the new merged upper index.
1273  index_t tmp = idx_up_new[Number<0>{}];
1274 
1275  static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
// Remember the previous index of this dimension before it is overwritten.
1276  index_t idx_low_old = idx_low[i];
1277 
1278  idx_low(i) =
1279  MagicDivision::DoMagicDivision(tmp,
1280  this->low_lengths_scan_magic_divisor_multiplier_[i],
1281  this->low_lengths_scan_magic_divisor_shift_[i]);
1282 
1283  idx_diff_low(i) = idx_low[i] - idx_low_old;
1284 
// Remove the part of tmp accounted for by this dimension.
1285  tmp -= idx_low[i] * this->low_lengths_scan_[i];
1286  });
1287 
// Order matters: the delta for the last dimension is taken while idx_low still holds the old value.
1288  idx_diff_low(Number<NDimLow - 1>{}) = tmp - idx_low[Number<NDimLow - 1>{}];
1289 
1290  idx_low(Number<NDimLow - 1>{}) = tmp;
1291  }
1292 
1293  __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
1294 
1295  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1296  {
1297  return true;
1298  }
1299 
1300  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1301  {
1306  }
1307 
1308  template <typename UpIdx>
1309  __host__ __device__ static constexpr bool
1310  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
1311  {
1312  return true;
1313  }
1314 
1315  __host__ __device__ void Print() const
1316  {
1317  printf("{");
1318  printf("Merge_v2r2_magic_division, ");
1319  printf("low_lengths_ ");
1320  print_multi_index(low_lengths_);
1321  printf("low_lengths_scan ");
1322  print_multi_index(low_lengths_scan_);
1323  printf("low_lengths_scan_magic_divisor_multiplier_ ");
1324  print_multi_index(low_lengths_scan_magic_divisor_multiplier_);
1325  printf("low_lengths_scan_magic_divisor_shift_ ");
1326  print_multi_index(low_lengths_scan_magic_divisor_shift_);
1327  printf("up_lengths_ ");
1328  print_multi_index(up_lengths_);
1329  printf("}");
1330  }
1331 };
1332 
1333 // Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to
1334 // be used for low_lengths that are known at compile time and are powers of 2, otherwise performance
1335 // will be very bad
1336 template <typename LowLengths>
1338 {
1339  static constexpr index_t NDimLow = LowLengths::Size();
1340 
1343 
1345  decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
1346 
1347  using UpLengths =
1348  decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
1349 
1350  LowLengths low_lengths_;
1353 
1354  __host__ __device__ constexpr Merge_v3_division_mod() = default;
1355 
1356  __host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths)
1357  : low_lengths_{low_lengths},
1358  low_lengths_scan_{
1359  container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
1360  up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
1361  {
1362  static_assert(LowerIndex::Size() == NDimLow, "wrong!");
1363  }
1364 
1365  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
1366 
1367  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1368 
1369  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1370 
// Decomposes the single merged upper index into NDimLow lower indices with plain integer
// division and modulo by the scanned lengths (low_lengths_scan_).
// idx_low : output, one entry per lower dimension.
// idx_up  : input, exactly one entry (the merged index).
1371  template <typename LowIdx, typename UpIdx>
1372  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1373  const UpIdx& idx_up) const
1374  {
1375  static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
1376  "wrong! inconsistent # of dimension");
1377 
// Running remainder; starts as the merged upper index.
1378  index_t tmp = idx_up[Number<0>{}];
1379 
1380  // division and mod
// Quotient becomes the index of dimension i, remainder is carried to the next dimension.
1381  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
1382  idx_low(i) = tmp / this->low_lengths_scan_[i];
1383  tmp %= this->low_lengths_scan_[i];
1384  });
1385 
// The final remainder is the index of the last dimension.
1386  idx_low(Number<NDimLow - 1>{}) = tmp;
1387  }
1388 
// Recomputes the lower index from the new upper index by division/modulo and reports, per
// lower dimension, how much it changed.
// idx_diff_low : output, new lower index minus previous lower index.
// (unnamed)    : the upper-index delta is accepted but not used.
// idx_low      : in/out, holds the previous lower index on entry and the new one on return.
// idx_up_new   : input, the upper index to decompose.
1389  template <typename LowIdxDiff,
1390  typename UpIdxDiff,
1391  typename LowIdx,
1392  typename UpIdx,
1393  index_t Hack>
1394  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1395  const UpIdxDiff&,
1396  LowIdx& idx_low,
1397  const UpIdx& idx_up_new,
1398  Number<Hack>) const
1399  {
1400  static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
1401  LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
1402  "wrong! inconsistent # of dimension");
1403 
1404  constexpr auto I0 = Number<0>{};
1405  constexpr auto INm1 = Number<NDimLow - 1>{};
1406 
// Running remainder; starts as the new merged upper index.
1407  index_t tmp = idx_up_new[I0];
1408 
1409  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
// tmp2 keeps the previous index of dimension i so the delta can be formed after the overwrite.
1410  const index_t tmp2 = idx_low[i];
1411  idx_low(i) = tmp / this->low_lengths_scan_[i];
1412  idx_diff_low(i) = idx_low[i] - tmp2;
1413  tmp %= this->low_lengths_scan_[i];
1414  });
1415 
// Last dimension: the final remainder is the new index; its delta is formed the same way.
1416  const index_t tmp2 = idx_low[INm1];
1417  idx_low(INm1) = tmp;
1418  idx_diff_low(INm1) = idx_low[INm1] - tmp2;
1419  }
1420 
1421  __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
1422 
1423  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1424  {
1425  return true;
1426  }
1427 
1428  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1429  {
1433  }
1434 
1435  template <typename UpIdx>
1436  __host__ __device__ static constexpr bool
1437  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
1438  {
1439  return true;
1440  }
1441 
1442  __host__ __device__ void Print() const
1443  {
1444  printf("{");
1445  printf("Merge_v3_direct_division_mod, ");
1446  printf("low_lengths_ ");
1447  print_multi_index(low_lengths_);
1448  printf("low_lengths_scan_ ");
1449  print_multi_index(low_lengths_scan_);
1450  printf("up_lengths_ ");
1451  print_multi_index(up_lengths_);
1452  printf("}");
1453  }
1454 };
1455 
1456 template <typename UpLengths, bool Use24BitIntegerCalculation>
1457 struct UnMerge
1458 {
1459  static constexpr index_t NDimUp = UpLengths::Size();
1460 
1463 
1465  decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, Number<1>{}));
1466 
1467  UpLengths up_lengths_;
1469 
1470  __host__ __device__ constexpr UnMerge() = default;
1471 
1472  __host__ __device__ constexpr UnMerge(const UpLengths& up_lengths)
1473  : up_lengths_{up_lengths},
1474  up_lengths_scan_{
1475  container_reverse_exclusive_scan(up_lengths, math::multiplies{}, Number<1>{})}
1476  {
1477  }
1478 
1479  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
1480 
1481  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
1482 
1483  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1484 
// Folds the NDimUp upper indices into the single lower index:
//   idx_low[0] = idx_up[NDimUp - 1] + sum over the loop of idx_up[i] * up_lengths_scan_[i].
// No dimension-count static_assert is performed here, unlike the other transforms in this file.
1485  template <typename LowIdx, typename UpIdx>
1486  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1487  const UpIdx& idx_up) const
1488  {
1489  if constexpr(!Use24BitIntegerCalculation)
1490  {
// Start from the last upper index, which is added without a scale factor.
1491  idx_low(Number<0>{}) = idx_up[Number<NDimUp - 1>{}];
1492 
1493  static_for<0, NDimUp - 1, 1>{}(
1494  [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
1495  }
1496  else
1497  {
1498  idx_low(Number<0>{}) = idx_up[Number<NDimUp - 1>{}];
1499 
// Same accumulation, but every operand is masked to its low 24 bits first.
// NOTE(review): presumably this lets the compiler pick 24-bit integer multiply/add
// instructions - confirm; values that need more than 24 bits are truncated by the mask.
1500  static_for<0, NDimUp - 1, 1>{}([&](auto i) {
1501  idx_low(Number<0>{}) =
1502  (0x00ffffff & idx_low[Number<0>{}]) +
1503  (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]);
1504  });
1505  }
1506  }
1507 
// Moves the lower index by the lowered form of an upper-index delta.
// idx_diff_low : output, the result of CalculateLowerIndex applied to idx_diff_up.
// idx_diff_up  : input, delta of the upper index.
// idx_low      : in/out, incremented by idx_diff_low.
// (unnamed)    : the new upper index is accepted but not used.
1508  template <typename LowIdxDiff,
1509  typename UpIdxDiff,
1510  typename LowIdx,
1511  typename UpIdx,
1512  index_t Hack>
1513  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1514  const UpIdxDiff& idx_diff_up,
1515  LowIdx& idx_low,
1516  const UpIdx&,
1517  Number<Hack>) const
1518  {
// The index mapping is reused on the delta itself (the transform reports IsLinearTransform() == true).
1519  CalculateLowerIndex(idx_diff_low, idx_diff_up);
1520 
1521  idx_low += idx_diff_low;
1522  }
1523 
1524  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
1525 
1526  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1527  {
1528  return true;
1529  }
1530 
1531  template <typename UpIdx>
1532  __host__ __device__ static constexpr bool
1533  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
1534  {
1535  return true;
1536  }
1537 
1538  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1539  {
1542  }
1543 
1544  __host__ __device__ void Print() const
1545  {
1546  printf("{");
1547  printf("UnMerge, ");
1548  printf("up_lengths_");
1549  print_multi_index(up_lengths_);
1550  printf("up_lengths_scan_");
1551  print_multi_index(up_lengths_scan_);
1552  printf("}");
1553  }
1554 };
1555 
1556 template <typename LowerIndex>
1557 struct Freeze
1558 {
1559  LowerIndex low_idx_;
1560 
1561  __host__ __device__ constexpr Freeze() = default;
1562 
1563  __host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
1564 
1565  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
1566 
1567  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 0; }
1568 
1569  __host__ __device__ static constexpr auto GetUpperLengths() { return Tuple<>{}; }
1570 
// Writes the stored constant low_idx_ into the single lower index; the (empty) upper index
// is ignored.
1571  template <typename LowIdx, typename UpIdx>
1572  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1573  const UpIdx& /* idx_up */) const
1574  {
// One lower dimension, zero upper dimensions.
1575  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0,
1576  "wrong! inconsistent # of dimension");
1577 
1578  idx_low(Number<0>{}) = low_idx_;
1579  }
1580 
// Reports a zero lower-index delta. idx_low and both upper-index arguments are left untouched,
// so the lower index keeps whatever value it already holds.
1581  template <typename LowIdxDiff,
1582  typename UpIdxDiff,
1583  typename LowIdx,
1584  typename UpIdx,
1585  index_t Hack>
1586  __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1587  const UpIdxDiff& /* idx_diff_up */,
1588  LowIdx& /* idx_low */,
1589  const UpIdx& /* idx_up_new */,
1590  Number<Hack>)
1591  {
1592  idx_diff_low(Number<0>{}) = 0;
1593  }
1594 
1595  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
1596 
1597  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1598  {
1599  return true;
1600  }
1601 
1602  template <typename UpIdx>
1603  __host__ __device__ static constexpr bool
1604  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
1605  {
1606  return true;
1607  }
1608 
1609  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1610  {
1612  }
1613 
1614  __host__ __device__ void Print() const
1615  {
1616  printf("Freeze");
1617  printf("low_idx_ %d", index_t{low_idx_});
1618  }
1619 };
1620 
1621 // Insert a dangling upper dimension without lower dimension
1622 template <typename UpperLength>
1623 struct Insert
1624 {
1625  using UpLengths = decltype(make_tuple(UpperLength{}));
1626 
1628 
1629  __host__ __device__ constexpr Insert() = default;
1630 
1631  __host__ __device__ constexpr Insert(const UpperLength& up_length)
1632  : up_lengths_{make_tuple(up_length)}
1633  {
1634  }
1635 
1636  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; }
1637 
1638  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1639 
1640  __host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; }
1641 
1642  template <typename LowIdx, typename UpIdx>
1643  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const
1644  {
1645  static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1,
1646  "wrong! inconsistent # of dimension");
1647  }
1648 
1649  template <typename LowIdxDiff,
1650  typename UpIdxDiff,
1651  typename LowIdx,
1652  typename UpIdx,
1653  index_t Hack>
1654  __host__ __device__ static void
1655  UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number<Hack>)
1656  {
1657  static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 &&
1658  UpIdx::Size() == 1,
1659  "wrong! inconsistent # of dimension");
1660  }
1661 
1662  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
1663 
1664  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1665  {
1666  return true;
1667  }
1668 
1669  template <typename UpIdx>
1670  __host__ __device__ static constexpr bool
1671  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
1672  {
1673  return true;
1674  }
1675 
1676  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1677  {
1679  }
1680 
1681  __host__ __device__ void Print() const
1682  {
1683  printf("Insert");
1684  print_multi_index(up_lengths_);
1685  }
1686 };
1687 
1688 template <typename VectorSize, typename UpLength>
1690 {
1693 
1694  using UpLengths = decltype(make_tuple(UpLength{}));
1695 
1697  VectorSize vector_size_;
1698 
1699  __host__ __device__ constexpr Vectorize() = default;
1700 
1701  __host__ __device__ constexpr Vectorize(const VectorSize& vector_size,
1702  const UpLength& up_length)
1703  : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
1704  {
1705  }
1706 
1707  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
1708 
1709  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1710 
1711  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1712 
// Maps the upper index to the lower index by scaling: idx_low[0] = vector_size_ * idx_up[0].
1713  template <typename LowIdx, typename UpIdx>
1714  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1715  const UpIdx& idx_up) const
1716  {
1717  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
1718  "wrong! inconsistent # of dimension");
1719 
1720  idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}];
1721  }
1722 
// Moves the lower index by the scaled upper-index delta.
// idx_diff_low : output, vector_size_ * idx_diff_up[0].
// idx_diff_up  : input, delta of the upper index.
// idx_low      : in/out, incremented by idx_diff_low.
// (unnamed)    : the new upper index is accepted but not used.
1723  template <typename LowIdxDiff,
1724  typename UpIdxDiff,
1725  typename LowIdx,
1726  typename UpIdx,
1727  index_t Hack>
1728  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1729  const UpIdxDiff& idx_diff_up,
1730  LowIdx& idx_low,
1731  const UpIdx&,
1732  Number<Hack>) const
1733  {
1734  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
1735  UpIdx::Size() == 1,
1736  "wrong! inconsistent # of dimension");
1737 
1738  constexpr auto I0 = Number<0>{};
1739 
1740  idx_diff_low(I0) = vector_size_ * idx_diff_up[I0];
1741 
1742  idx_low += idx_diff_low;
1743  }
1744 
1745  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
1746 
1747  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1748  {
1749  return true;
1750  }
1751 
1752  template <typename UpIdx>
1753  __host__ __device__ static constexpr bool
1754  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
1755  {
1756  return true;
1757  }
1758 
1759  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1760  {
1762  }
1763 
1764  __host__ __device__ void Print() const
1765  {
1766  printf("{");
1767  printf("Vectorize, ");
1768  printf("up_lengths_");
1769  print_multi_index(up_lengths_);
1770  printf("}");
1771  }
1772 };
1773 
1774 template <typename LowLength, typename SliceBegin, typename SliceEnd>
1775 struct Slice
1776 {
1779 
1780  using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));
1781 
1783  SliceBegin slice_begin_;
1784  SliceEnd slice_end_;
1785 
1786  __host__ __device__ constexpr Slice() = default;
1787 
1788  __host__ __device__ constexpr Slice(const LowLength&,
1789  const SliceBegin& slice_begin,
1790  const SliceEnd& slice_end)
1791  : up_lengths_{make_tuple(slice_end - slice_begin)},
1792  slice_begin_{slice_begin},
1793  slice_end_{slice_end}
1794  {
1795  }
1796 
1797  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
1798 
1799  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1800 
1801  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1802 
// Maps the upper index to the lower index by offsetting: idx_low[0] = idx_up[0] + slice_begin_.
1803  template <typename LowIdx, typename UpIdx>
1804  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1805  const UpIdx& idx_up) const
1806  {
1807  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
1808  "wrong! inconsistent # of dimension");
1809 
1810  idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_;
1811  }
1812 
// Moves the lower index by exactly the upper-index delta; slice_begin_ is a constant offset and
// therefore does not appear in the delta, which is why this method can be static.
// idx_diff_low : output, equal to idx_diff_up[0].
// idx_diff_up  : input, delta of the upper index.
// idx_low      : in/out, incremented by idx_diff_low.
// (unnamed)    : the new upper index is accepted but not used.
1813  template <typename LowIdxDiff,
1814  typename UpIdxDiff,
1815  typename LowIdx,
1816  typename UpIdx,
1817  index_t Hack>
1818  __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1819  const UpIdxDiff& idx_diff_up,
1820  LowIdx& idx_low,
1821  const UpIdx&,
1822  Number<Hack>)
1823  {
1824  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
1825  UpIdx::Size() == 1,
1826  "wrong! inconsistent # of dimension");
1827 
1828  constexpr auto I0 = Number<0>{};
1829 
1830  idx_diff_low(I0) = idx_diff_up[I0];
1831 
1832  idx_low += idx_diff_low;
1833  }
1834 
1835  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
1836 
1837  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1838  {
1839  return true;
1840  }
1841 
1842  template <typename UpIdx>
1843  __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const
1844  {
1845  return true;
1846  }
1847 
1848  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1849  {
1853  }
1854 
1855  __host__ __device__ void Print() const
1856  {
1857  printf("{");
1858  printf("Slice, ");
1859  printf("up_lengths_");
1860  print_multi_index(up_lengths_);
1861  printf("slice_begin_ %d", index_t{slice_begin_});
1862  printf("slice_end %d", index_t{slice_end_});
1863  printf("}");
1864  }
1865 };
1866 
1867 /*
1868  * \brief lower_idx = upper_idx % modulus.
1869  * TODO: Need an improved implementation since the modulo operation is expensive.
1870  */
1871 template <typename Modulus, typename UpLength>
1872 struct Modulo
1873 {
1876  using UpLengths = decltype(make_tuple(UpLength{}));
1877 
1878  Modulus modulus_;
1880 
1881  __host__ __device__ constexpr Modulo() = default;
1882 
1883  __host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length)
1884  : modulus_{modulus}, up_lengths_{make_tuple(up_length)}
1885  {
1886  }
1887 
1888  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
1889 
1890  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1891 
1892  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1893 
// Maps the upper index to the lower index by reduction: idx_low[0] = idx_up[0] % modulus_.
1894  template <typename LowIdx, typename UpIdx>
1895  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1896  const UpIdx& idx_up) const
1897  {
1898  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
1899  "wrong! inconsistent # of dimension");
1900 
1901  idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_;
1902  }
1903 
// Recomputes the lower index as (up_idx[0] + idx_diff_up[0]) % modulus_ and reports how much
// it changed.
// idx_diff_low : output, new lower index minus previous lower index (single element).
// idx_diff_up  : input, delta of the upper index.
// idx_low      : in/out, holds the previous lower index on entry and the new one on return.
// up_idx       : input, upper index the delta is added to.
// NOTE(review): the other transforms in this file receive the already-updated upper index in
// this position (named idx_up_new there). If the caller passes the updated index here as well,
// adding idx_diff_up counts the delta twice - confirm against the caller.
1904  template <typename LowIdxDiff,
1905  typename UpIdxDiff,
1906  typename LowIdx,
1907  typename UpIdx,
1908  index_t Hack>
1909  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1910  const UpIdxDiff& idx_diff_up,
1911  LowIdx& idx_low,
1912  const UpIdx& up_idx,
1913  Number<Hack>) const
1914  {
1915  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
1916  UpIdx::Size() == 1,
1917  "wrong! inconsistent # of dimension");
1918 
1919  constexpr auto I0 = Number<0>{};
1920 
// Snapshot of the previous lower index, needed to form the delta after the overwrite.
1921  const auto idx_low_old = idx_low;
// up_idx and idx_diff_up are const references, so they are read with operator[], matching
// how every other transform in this file reads its inputs.
1922  idx_low(I0) = (up_idx[I0] + idx_diff_up[I0]) % modulus_;
// Element-wise difference: the target is a single element, not a whole multi-index.
1923  idx_diff_low(I0) = idx_low[I0] - idx_low_old[I0];
1924  }
1925 
1926  __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
1927 
1928  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1929  {
1930  return true;
1931  }
1932 
1933  template <typename UpIdx>
1934  __host__ __device__ static constexpr bool
1935  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
1936  {
1937  return true;
1938  }
1939 
1940  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1941  {
1943  }
1944 
1945  __host__ __device__ void Print() const
1946  {
1947  printf("{");
1948  printf("Modulus, ");
1949  printf("up_lengths_");
1950  print_multi_index(up_lengths_);
1951  printf("}");
1952  }
1953 };
1954 
1955 template <typename LowLengths, bool ApplyModulo>
1956 struct Xor
1957 {
1960 
1961  using UpLengths = LowLengths;
1962 
1964 
1965  __host__ __device__ constexpr Xor() : up_lengths_{} {}
1966 
1967  __host__ __device__ constexpr Xor(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
1968 
1969  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 2; }
1970 
1971  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 2; }
1972 
1973  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1974 
// Maps a 2-D upper index to a 2-D lower index.
// Dimension 0 is copied through; dimension 1 is XOR-ed with dimension 0 of the upper index,
// optionally after reducing that value modulo up_lengths_[1].
1975  template <typename LowIdx, typename UpIdx>
1976  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1977  const UpIdx& idx_up) const
1978  {
1979  static_assert(LowIdx::Size() == 2 && UpIdx::Size() == 2,
1980  "wrong! inconsistent # of dimension");
1981 
1982  idx_low(Number<0>{}) = idx_up[Number<0>{}];
1983 
1984  if constexpr(ApplyModulo)
1985  {
// Reduce idx_up[0] modulo the length of dimension 1 before it is used as the XOR operand.
1986  idx_low(Number<1>{}) =
1987  idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]);
1988  }
1989  else
1990  {
// idx_up[0] is used as the XOR operand as-is.
1991  idx_low(Number<1>{}) = idx_up[Number<1>{}] ^ idx_up[Number<0>{}];
1992  }
1993  }
1994 
// Recomputes the lower index from the given upper index and reports how much it changed.
// idx_diff_low : output, new lower index minus previous lower index (whole multi-index).
// (unnamed)    : the upper-index delta is accepted but not used.
// idx_low      : in/out, holds the previous lower index on entry and the new one on return.
// idx_up       : input, upper index passed straight to CalculateLowerIndex.
1995  template <typename LowIdxDiff,
1996  typename UpIdxDiff,
1997  typename LowIdx,
1998  typename UpIdx,
1999  index_t Hack>
2000  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
2001  const UpIdxDiff&,
2002  LowIdx& idx_low,
2003  const UpIdx& idx_up,
2004  Number<Hack>) const
2005  {
2006  static_assert(LowIdxDiff::Size() == 2 && UpIdxDiff::Size() == 2 && LowIdx::Size() == 2 &&
2007  UpIdx::Size() == 2,
2008  "wrong! inconsistent # of dimension");
2009 
// Snapshot of the previous lower index, taken before it is overwritten.
2010  const auto idx_low_old = idx_low;
2011 
2012  CalculateLowerIndex(idx_low, idx_up);
2013 
2014  idx_diff_low = idx_low - idx_low_old;
2015  }
2016 
2017  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
2018  {
2019  return true;
2020  }
2021 
2022  template <typename UpIdx>
2023  __host__ __device__ static constexpr bool
2024  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
2025  {
2026  return true;
2027  }
2028 
2029  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
2030  {
2032  }
2033 
2034  __host__ __device__ void Print() const
2035  {
2036  printf("Xor{");
2037 
2038  //
2039  printf("up_lengths_: ");
2040  print(up_lengths_);
2041  printf(", ");
2042 
2043  printf("}");
2044  }
2045 };
2046 } // namespace ck
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
Definition: ck.hpp:267
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
__host__ constexpr __device__ auto container_reverse_exclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition: container_helper.hpp:213
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition: container_helper.hpp:111
__host__ __device__ void print_multi_index(const Tuple< Xs... > &x)
Definition: statically_indexed_array_multi_index.hpp:147
Definition: array.hpp:14
Definition: multi_index_transform.hpp:385
static constexpr index_t NDimUp
Definition: multi_index_transform.hpp:386
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:406
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:454
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:404
Coefficients coefficients_
Definition: multi_index_transform.hpp:392
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:445
__host__ constexpr __device__ Embed()=default
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:409
UpLengths up_lengths_
Definition: multi_index_transform.hpp:391
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:459
__host__ constexpr __device__ Embed(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: multi_index_transform.hpp:396
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:465
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:447
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition: multi_index_transform.hpp:427
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:402
Definition: multi_index_transform.hpp:1558
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1609
LowerIndex low_idx_
Definition: multi_index_transform.hpp:1559
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1565
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1595
__host__ static constexpr __device__ auto GetUpperLengths()
Definition: multi_index_transform.hpp:1569
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:1586
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1567
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1614
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1604
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &) const
Definition: multi_index_transform.hpp:1572
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1597
__host__ constexpr __device__ Freeze(const LowerIndex &low_idx)
Definition: multi_index_transform.hpp:1563
__host__ constexpr __device__ Freeze()=default
Definition: multi_index_transform.hpp:1624
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1638
__host__ constexpr __device__ Insert()=default
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &, const UpIdx &) const
Definition: multi_index_transform.hpp:1643
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1636
__host__ constexpr __device__ Insert(const UpperLength &up_length)
Definition: multi_index_transform.hpp:1631
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1676
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1662
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1681
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1627
decltype(make_tuple(UpperLength{})) UpLengths
Definition: multi_index_transform.hpp:1625
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1671
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &, const UpIdxDiff &, LowIdx &, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:1655
__host__ constexpr __device__ auto GetUpperLengths() const
Definition: multi_index_transform.hpp:1640
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1664
Definition: multi_index_transform.hpp:196
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:215
__host__ constexpr __device__ LeftPad(const LowLength &low_length, const LeftPadLength &left_pad_length)
Definition: multi_index_transform.hpp:207
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:251
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:265
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:213
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:220
decltype(make_tuple(LowLength{}+LeftPadLength{})) UpLengths
Definition: multi_index_transform.hpp:200
__host__ constexpr __device__ LeftPad()=default
LeftPadLength left_pad_length_
Definition: multi_index_transform.hpp:203
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:234
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:271
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:217
UpLengths up_lengths_
Definition: multi_index_transform.hpp:202
__host__ constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:260
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:253
Definition: multi_index_transform.hpp:481
__host__ __device__ void UpdateLowerIndex_2(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition: multi_index_transform.hpp:822
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:967
LowLengths low_lengths_
Definition: multi_index_transform.hpp:493
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:508
__host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition: multi_index_transform.hpp:680
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:969
UpLengths up_lengths_
Definition: multi_index_transform.hpp:495
static constexpr index_t NDimLow
Definition: multi_index_transform.hpp:482
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:983
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:515
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition: multi_index_transform.hpp:491
__host__ constexpr __device__ Merge_v1_carry_check(const LowLengths &low_lengths)
Definition: multi_index_transform.hpp:499
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:974
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:510
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number< 1 >{})) LowLengthsScan
Definition: multi_index_transform.hpp:488
__host__ constexpr __device__ Merge_v1_carry_check()=default
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:988
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:512
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition: multi_index_transform.hpp:952
LowLengthsScan low_lengths_scan_
Definition: multi_index_transform.hpp:494
__host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition: multi_index_transform.hpp:537
Definition: multi_index_transform.hpp:1036
LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_
Definition: multi_index_transform.hpp:1055
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1138
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1056
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1077
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1073
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1075
LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_
Definition: multi_index_transform.hpp:1054
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition: multi_index_transform.hpp:1043
LowLengths low_lengths_
Definition: multi_index_transform.hpp:1053
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1136
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1080
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition: multi_index_transform.hpp:1105
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1153
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1143
decltype(generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier< LowLengths >{}, Number< NDimLow >{})) LowLengthsMagicDivisorMultipiler
Definition: multi_index_transform.hpp:1047
__host__ constexpr __device__ Merge_v2_magic_division(const LowLengths &low_lengths)
Definition: multi_index_transform.hpp:1060
decltype(generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift< LowLengths >{}, Number< NDimLow >{})) LowLengthsMagicDivisorShift
Definition: multi_index_transform.hpp:1051
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1158
__host__ constexpr __device__ Merge_v2_magic_division()=default
Definition: multi_index_transform.hpp:1188
LowLengths low_lengths_
Definition: multi_index_transform.hpp:1208
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1315
__host__ constexpr __device__ Merge_v2r2_magic_division()=default
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number< 1 >{})) LowLengthsScan
Definition: multi_index_transform.hpp:1195
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1295
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition: multi_index_transform.hpp:1198
LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_
Definition: multi_index_transform.hpp:1211
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1212
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1233
LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_
Definition: multi_index_transform.hpp:1210
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1293
__host__ constexpr __device__ Merge_v2r2_magic_division(const LowLengths &low_lengths)
Definition: multi_index_transform.hpp:1216
decltype(generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift< LowLengthsScan >{}, Number< NDimLow >{})) LowLengthsScanMagicDivisorShift
Definition: multi_index_transform.hpp:1206
decltype(generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier< LowLengthsScan >{}, Number< NDimLow >{})) LowLengthsScanMagicDivisorMultipiler
Definition: multi_index_transform.hpp:1202
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition: multi_index_transform.hpp:1263
LowLengthsScan low_lengths_scan_
Definition: multi_index_transform.hpp:1209
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1235
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1310
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1231
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1300
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1238
Definition: multi_index_transform.hpp:1338
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1352
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition: multi_index_transform.hpp:1394
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1442
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1365
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1437
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1428
LowLengthsScan low_lengths_scan_
Definition: multi_index_transform.hpp:1351
__host__ constexpr __device__ Merge_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform.hpp:1356
LowLengths low_lengths_
Definition: multi_index_transform.hpp:1350
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1367
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1372
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition: multi_index_transform.hpp:1348
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1369
__host__ constexpr __device__ Merge_v3_division_mod()=default
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1423
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number< 1 >{})) LowLengthsScan
Definition: multi_index_transform.hpp:1345
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1421
Definition: multi_index_transform.hpp:1873
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1935
__host__ constexpr __device__ Modulo(const Modulus &modulus, const UpLength &up_length)
Definition: multi_index_transform.hpp:1883
Modulus modulus_
Definition: multi_index_transform.hpp:1878
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1945
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1928
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1890
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &up_idx, Number< Hack >) const
Definition: multi_index_transform.hpp:1909
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1892
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1940
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1926
decltype(make_tuple(UpLength{})) UpLengths
Definition: multi_index_transform.hpp:1876
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1888
__host__ constexpr __device__ Modulo()=default
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1879
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1895
Definition: multi_index_transform.hpp:100
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:142
LeftPadLength left_pad_length_
Definition: multi_index_transform.hpp:107
__host__ constexpr __device__ Pad()=default
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:125
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:128
decltype(make_tuple(LowLength{}+LeftPadLength{}+RightPadLength{})) UpLengths
Definition: multi_index_transform.hpp:104
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:159
__host__ constexpr __device__ Pad(const LowLength &low_length, const LeftPadLength &left_pad_length, const RightPadLength &right_pad_length)
Definition: multi_index_transform.hpp:112
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:161
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:175
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:123
UpLengths up_lengths_
Definition: multi_index_transform.hpp:106
RightPadLength right_pad_length_
Definition: multi_index_transform.hpp:108
__host__ constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:168
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:182
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:121
Definition: multi_index_transform.hpp:13
__host__ constexpr __device__ PassThrough(const LowLength &low_length)
Definition: multi_index_transform.hpp:23
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:30
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:68
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:32
UpLengths up_lengths_
Definition: multi_index_transform.hpp:19
__host__ constexpr __device__ PassThrough()=default
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:66
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:49
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:85
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:75
decltype(make_tuple(LowLength{})) UpLengths
Definition: multi_index_transform.hpp:17
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:28
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:80
__host__ static constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up)
Definition: multi_index_transform.hpp:35
Definition: multi_index_transform.hpp:284
__host__ static constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up)
Definition: multi_index_transform.hpp:311
__host__ constexpr __device__ RightPad()=default
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:342
__host__ constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:351
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:308
decltype(make_tuple(LowLength{}+RightPadLength{})) UpLengths
Definition: multi_index_transform.hpp:288
__host__ constexpr __device__ RightPad(const LowLength &low_length, const RightPadLength &right_pad_length)
Definition: multi_index_transform.hpp:296
UpLengths up_lengths_
Definition: multi_index_transform.hpp:290
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:356
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:304
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:325
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:306
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:363
LowLength low_length_
Definition: multi_index_transform.hpp:291
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:344
RightPadLength right_pad_length_
Definition: multi_index_transform.hpp:292
Definition: multi_index_transform.hpp:1776
SliceEnd slice_end_
Definition: multi_index_transform.hpp:1784
decltype(make_tuple(SliceEnd{} - SliceBegin{})) UpLengths
Definition: multi_index_transform.hpp:1780
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1797
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1855
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1782
SliceBegin slice_begin_
Definition: multi_index_transform.hpp:1783
__host__ constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &) const
Definition: multi_index_transform.hpp:1843
__host__ constexpr __device__ Slice(const LowLength &, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: multi_index_transform.hpp:1788
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1804
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1837
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1801
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:1818
__host__ constexpr __device__ Slice()=default
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1799
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1848
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1835
Definition: tuple.hpp:186
Definition: multi_index_transform.hpp:1458
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1467
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1533
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1483
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1544
UpLengthsScan up_lengths_scan_
Definition: multi_index_transform.hpp:1468
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1524
__host__ constexpr __device__ UnMerge()=default
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1538
__host__ constexpr __device__ UnMerge(const UpLengths &up_lengths)
Definition: multi_index_transform.hpp:1472
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1486
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition: multi_index_transform.hpp:1513
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1479
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1481
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1526
decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, Number< 1 >{})) UpLengthsScan
Definition: multi_index_transform.hpp:1465
Definition: multi_index_transform.hpp:1690
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1711
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1747
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1759
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition: multi_index_transform.hpp:1728
__host__ constexpr __device__ Vectorize()=default
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1709
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1707
VectorSize vector_size_
Definition: multi_index_transform.hpp:1697
__host__ constexpr __device__ Vectorize(const VectorSize &vector_size, const UpLength &up_length)
Definition: multi_index_transform.hpp:1701
decltype(make_tuple(UpLength{})) UpLengths
Definition: multi_index_transform.hpp:1694
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1696
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1714
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1764
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1745
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1754
Definition: multi_index_transform.hpp:1957
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1973
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1976
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1971
__host__ constexpr __device__ Xor()
Definition: multi_index_transform.hpp:1965
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1963
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up, Number< Hack >) const
Definition: multi_index_transform.hpp:2000
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:2017
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:2024
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:2034
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:2029
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1969
__host__ constexpr __device__ Xor(const LowLengths &low_lengths)
Definition: multi_index_transform.hpp:1967
LowLengths UpLengths
Definition: multi_index_transform.hpp:1961
Definition: integral_constant.hpp:20
Definition: is_known_at_compile_time.hpp:14
__host__ constexpr __device__ auto operator()(Number< I > i) const
Definition: multi_index_transform.hpp:1006
Definition: multi_index_transform.hpp:1014
__host__ constexpr __device__ auto operator()(Number< I > i) const
Definition: multi_index_transform.hpp:1016
Definition: math.hpp:34
Definition: functional2.hpp:33