/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_description/multi_index_transform.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_description/multi_index_transform.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_description/multi_index_transform.hpp Source File
multi_index_transform.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 
// NOTE(review): this block is a rendered listing of the header with the original file's
// line numbers fused onto each line. Original lines 12 (the struct declaration), 14-15,
// 19, 53, 82 and 90 are missing here (the numbering jumps), so the block does not compile
// as shown; restore them from the real header.
//
// PassThrough: identity transform with one upper and one lower dimension. The single
// upper length equals the lower length, and upper indices / index deltas are copied to
// the lower side unchanged.
11 template <typename LowLength>
13 {
16 
17  using UpLengths = decltype(make_tuple(LowLength{}));
18 
20 
21  __host__ __device__ constexpr PassThrough() = default;
22 
 // Stores the upper lengths as the one-element tuple (low_length).
23  __host__ __device__ constexpr PassThrough(const LowLength& low_length)
24  : up_lengths_{make_tuple(low_length)}
25  {
26  }
27 
28  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
29 
30  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
31 
32  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
33 
 // Lowers an upper index: idx_low[0] = idx_up[0].
34  template <typename LowIdx, typename UpIdx>
35  __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
36  const UpIdx& idx_up)
37  {
38  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
39  "wrong! inconsistent # of dimension");
40 
41  idx_low(Number<0>{}) = idx_up[Number<0>{}];
42  }
43 
 // Lowers an upper index delta: idx_diff_low[0] = idx_diff_up[0], then advances idx_low
 // by idx_diff_low. The new upper index argument is not used.
 // NOTE(review): original line 53 (presumably the trailing `Number<Hack>)` parameter, as in
 // the Pad/LeftPad/RightPad signatures) is missing from this listing.
44  template <typename LowIdxDiff,
45  typename UpIdxDiff,
46  typename LowIdx,
47  typename UpIdx,
48  index_t Hack>
49  __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
50  const UpIdxDiff& idx_diff_up,
51  LowIdx& idx_low,
52  const UpIdx&,
54  {
55  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
56  UpIdx::Size() == 1,
57  "wrong! inconsistent # of dimension");
58 
59  constexpr auto I0 = Number<0>{};
60 
61  idx_diff_low(I0) = idx_diff_up[I0];
62 
63  idx_low += idx_diff_low;
64  }
65 
66  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
67 
 // Identity mapping: every valid upper index maps to a valid lower index.
68  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
69  {
70  return true;
71  }
72 
73  template <typename UpIdx>
74  __host__ __device__ static constexpr bool
75  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
76  {
77  return true;
78  }
79 
 // NOTE(review): the return statement of this function (original line 82) is missing
 // from this listing.
80  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
81  {
83  }
84 
 // Debug print of the transform via printf (no trailing newline).
 // NOTE(review): original line 90 is missing; presumably it prints up_lengths_ (compare
 // Merge_v1_carry_check::Print, which follows "up_lengths_ " with print_multi_index).
85  __host__ __device__ void Print() const
86  {
87  printf("{");
88  printf("PassThrough, ");
89  printf("up_lengths_");
91  printf("}");
92  }
93 };
94 
95 template <typename LowLength,
96  typename LeftPadLength,
97  typename RightPadLength,
98  bool SkipIsValidCheck = false>
99 struct Pad
100 {
103 
104  using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));
105 
107  LeftPadLength left_pad_length_;
108  RightPadLength right_pad_length_;
109 
110  __host__ __device__ constexpr Pad() = default;
111 
112  __host__ __device__ constexpr Pad(const LowLength& low_length,
113  const LeftPadLength& left_pad_length,
114  const RightPadLength& right_pad_length)
115  : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
116  left_pad_length_{left_pad_length},
117  right_pad_length_{right_pad_length}
118  {
119  }
120 
121  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
122 
123  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
124 
125  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
126 
127  template <typename LowIdx, typename UpIdx>
128  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
129  const UpIdx& idx_up) const
130  {
131  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
132  "wrong! inconsistent # of dimension");
133 
134  idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
135  }
136 
137  template <typename LowIdxDiff,
138  typename UpIdxDiff,
139  typename LowIdx,
140  typename UpIdx,
141  index_t Hack>
142  __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
143  const UpIdxDiff& idx_diff_up,
144  LowIdx& idx_low,
145  const UpIdx&,
146  Number<Hack>)
147  {
148  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
149  UpIdx::Size() == 1,
150  "wrong! inconsistent # of dimension");
151 
152  constexpr auto I0 = Number<0>{};
153 
154  idx_diff_low(I0) = idx_diff_up[I0];
155 
156  idx_low += idx_diff_low;
157  }
158 
159  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
160 
161  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
162  {
163  return SkipIsValidCheck;
164  }
165 
166  template <typename UpIdx>
167  __host__ __device__ constexpr bool
168  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
169  {
170  return SkipIsValidCheck ||
171  ((idx_up[Number<0>{}] >= left_pad_length_) &&
172  (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_length_));
173  }
174 
175  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
176  {
180  }
181 
182  __host__ __device__ void Print() const
183  {
184  printf("{");
185  printf("Pad, ");
186  printf("up_lengths_");
188  printf("left_pad_length %d", index_t{left_pad_length_});
189  printf("right_pad_length %d", index_t{right_pad_length_});
190  printf("}");
191  }
192 };
193 
// NOTE(review): rendered listing with the original line numbers fused onto each line;
// original lines 197-198, 202, 267-268 and 276 are missing here (the numbering jumps).
//
// LeftPad: one upper and one lower dimension. The upper length is
// low_length + left_pad_length, and the lower index is idx_up[0] - left_pad_length.
// Upper indices below left_pad_length do not map to a valid lower index unless
// SkipIsValidCheck is true.
194 template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
195 struct LeftPad
196 {
199 
200  using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));
201 
203  LeftPadLength left_pad_length_;
204 
205  __host__ __device__ constexpr LeftPad() = default;
206 
 // Upper length = low_length + left_pad_length.
207  __host__ __device__ constexpr LeftPad(const LowLength& low_length,
208  const LeftPadLength& left_pad_length)
209  : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
210  {
211  }
212 
213  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
214 
215  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
216 
217  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
218 
 // Lowers an upper index: idx_low[0] = idx_up[0] - left_pad_length_.
219  template <typename LowIdx, typename UpIdx>
220  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
221  const UpIdx& idx_up) const
222  {
223  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
224  "wrong! inconsistent # of dimension");
225 
226  idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
227  }
228 
 // Lowers an index delta. The constant left_pad offset cancels, so the lower delta
 // equals the upper delta; idx_low is then advanced by it.
229  template <typename LowIdxDiff,
230  typename UpIdxDiff,
231  typename LowIdx,
232  typename UpIdx,
233  index_t Hack>
234  __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
235  const UpIdxDiff& idx_diff_up,
236  LowIdx& idx_low,
237  const UpIdx&,
238  Number<Hack>)
239  {
240  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
241  UpIdx::Size() == 1,
242  "wrong! inconsistent # of dimension");
243 
244  constexpr auto I0 = Number<0>{};
245 
246  idx_diff_low(I0) = idx_diff_up[I0];
247 
248  idx_low += idx_diff_low;
249  }
250 
251  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
252 
 // Returns SkipIsValidCheck: only when it is true is every upper index reported as
 // mapping to a valid lower index.
253  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
254  {
255  return SkipIsValidCheck;
256  }
257 
 // Valid iff idx_up[0] >= left_pad_length_ (or SkipIsValidCheck is true). No upper
 // bound is checked here.
258  template <typename UpIdx>
259  __host__ __device__ constexpr bool
260  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
261  {
262  return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_length_);
263  }
264 
 // NOTE(review): the body of this function (original lines 267-268) is missing from
 // this listing.
265  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
266  {
269  }
270 
 // Debug print of the transform via printf (no trailing newline).
 // NOTE(review): original line 276 is missing; presumably it prints up_lengths_.
271  __host__ __device__ void Print() const
272  {
273  printf("{");
274  printf("LeftPad, ");
275  printf("up_lengths_");
277  printf("left_pad_length_ %d", index_t{left_pad_length_});
278  printf("}");
279  }
280 };
281 
282 template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
283 struct RightPad
284 {
287 
288  using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));
289 
291  LowLength low_length_;
292  RightPadLength right_pad_length_;
293 
294  __host__ __device__ constexpr RightPad() = default;
295 
296  __host__ __device__ constexpr RightPad(const LowLength& low_length,
297  const RightPadLength& right_pad_length)
298  : up_lengths_{make_tuple(low_length + right_pad_length)},
299  low_length_{low_length},
300  right_pad_length_{right_pad_length}
301  {
302  }
303 
304  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
305 
306  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
307 
308  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
309 
310  template <typename LowIdx, typename UpIdx>
311  __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
312  const UpIdx& idx_up)
313  {
314  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
315  "wrong! inconsistent # of dimension");
316 
317  idx_low(Number<0>{}) = idx_up[Number<0>{}];
318  }
319 
320  template <typename LowIdxDiff,
321  typename UpIdxDiff,
322  typename LowIdx,
323  typename UpIdx,
324  index_t Hack>
325  __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
326  const UpIdxDiff& idx_diff_up,
327  LowIdx& idx_low,
328  const UpIdx&,
329  Number<Hack>)
330  {
331  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
332  UpIdx::Size() == 1,
333  "wrong! inconsistent # of dimension");
334 
335  constexpr auto I0 = Number<0>{};
336 
337  idx_diff_low(I0) = idx_diff_up[I0];
338 
339  idx_low += idx_diff_low;
340  }
341 
342  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
343 
344  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
345  {
346  return SkipIsValidCheck;
347  }
348 
349  template <typename UpIdx>
350  __host__ __device__ constexpr bool
351  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
352  {
353  return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_);
354  }
355 
356  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
357  {
361  }
362 
363  __host__ __device__ void Print() const
364  {
365  printf("{");
366  printf("RightPad, ");
367  printf("up_lengths_");
369  printf("low_length_ %d", index_t{low_length_});
370  printf("left_pad_length_ %d", index_t{right_pad_length_});
371  printf("}");
372  }
373 };
374 
375 // idx_low = sum over i in [0, NDimUp) of coefficients[i] * idx_up[i]
376 // UpLengths and Coefficients can each be any of the following:
377 // 1) a Tuple of index_t, known at run-time, or
378 // 2) a Tuple of Number, known at compile-time, or
379 // 3) a Tuple mixing index_t and Number, known partially at run-time and partially
380 // at compile-time.
// NOTE(review): rendered listing with the original line numbers fused onto each line;
// original lines 388-389, 439, 461-462, 470 and 472 are missing here (the numbering jumps).
//
// Embed: NDimUp upper dimensions, one lower dimension. The lower index is the dot
// product of the upper index with coefficients_. The enable_if parameter rejects
// UpLengths/Coefficients of different sizes.
381 template <typename UpLengths,
382  typename Coefficients,
383  typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
384 struct Embed
385 {
386  static constexpr index_t NDimUp = UpLengths::Size();
387 
390 
391  UpLengths up_lengths_;
392  Coefficients coefficients_;
393 
394  __host__ __device__ constexpr Embed() = default;
395 
396  __host__ __device__ constexpr Embed(const UpLengths& up_lengths,
397  const Coefficients& coefficients)
398  : up_lengths_{up_lengths}, coefficients_{coefficients}
399  {
400  }
401 
402  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
403 
404  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
405 
406  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
407 
 // Lowers an upper index: idx_low[0] = sum_i idx_up[i] * coefficients_[i], unrolled at
 // compile time over i in [0, NDimUp).
408  template <typename LowIdx, typename UpIdx>
409  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
410  const UpIdx& idx_up) const
411  {
412  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
413  "wrong! inconsistent # of dimension");
414 
415  idx_low(Number<0>{}) = 0;
416 
417  static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
418  idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i];
419  });
420  }
421 
 // Lowers an index delta. Because the map is linear, the lower delta is the same dot
 // product applied to idx_diff_up; idx_low is then advanced by it.
 // NOTE(review): original line 439 (presumably the `static_for<0, NDimUp, 1>{}(` head of
 // the loop whose lambda is on line 440) is missing from this listing.
422  template <typename LowIdxDiff,
423  typename UpIdxDiff,
424  typename LowIdx,
425  typename UpIdx,
426  index_t Hack>
427  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
428  const UpIdxDiff& idx_diff_up,
429  LowIdx& idx_low,
430  const UpIdx&,
431  Number<Hack>) const
432  {
433  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&
434  LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
435  "wrong! inconsistent # of dimension");
436 
437  idx_diff_low(Number<0>{}) = 0;
438 
440  [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
441 
442  idx_low += idx_diff_low;
443  }
444 
445  __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
446 
 // No bounds check: every upper index is reported as mapping to a valid lower index.
447  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
448  {
449  return true;
450  }
451 
452  template <typename UpIdx>
453  __host__ __device__ static constexpr bool
454  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
455  {
456  return true;
457  }
458 
 // NOTE(review): the body of this function (original lines 461-462) is missing from
 // this listing.
459  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
460  {
463  }
464 
 // Debug print of the transform via printf (no trailing newline).
 // NOTE(review): original lines 470 and 472 are missing; presumably they print
 // up_lengths_ and coefficients_ respectively.
465  __host__ __device__ void Print() const
466  {
467  printf("{");
468  printf("Embed, ");
469  printf("up_lengths_ ");
471  printf("coefficients_ ");
473  printf("}");
474  }
475 };
476 
477 // Implementation of "Merge" transformation primitive that uses regular division to lower the
478 // multi-index and a carry-and-borrow check to lower the multi-index delta
// NOTE(review): rendered listing with the original line numbers fused onto each line;
// original lines 480 (the struct declaration), 484-485, 487, 494-495, 501 and 976-978 are
// missing here (the numbering jumps), so the block does not compile as shown.
//
// Merge_v1_carry_check: merges NDimLow lower dimensions into one upper dimension whose
// length is the product of the lower lengths. low_lengths_scan_ is built with
// container_reverse_exclusive_scan(low_lengths, multiplies, 1); presumably
// low_lengths_scan_[i] is the product of low_lengths_[i+1 .. NDimLow-1] (the stride of
// lower dimension i).
479 template <typename LowLengths>
481 {
482  static constexpr index_t NDimLow = LowLengths::Size();
483 
486 
 // NOTE(review): the head of this type alias (original line 487) is missing; the
 // decltype below is its right-hand side.
488  decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
489 
490  using UpLengths =
491  decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
492 
493  LowLengths low_lengths_;
496 
497  __host__ __device__ constexpr Merge_v1_carry_check() = default;
498 
 // Stores the lower lengths, their reverse exclusive product scan (initializer head on
 // missing line 501), and the single upper length = product of all lower lengths.
499  __host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths)
500  : low_lengths_{low_lengths},
502  container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
503  up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
504  {
505  static_assert(LowerIndex::Size() == NDimLow, "wrong!");
506  }
507 
508  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
509 
510  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
511 
512  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
513 
 // Lowers an upper index by mixed-radix decomposition: for each dimension except the
 // last, idx_low[i] = tmp / low_lengths_scan_[i] and tmp keeps the remainder; the last
 // dimension receives what is left. Costs NDimLow-1 integer divisions.
514  template <typename LowIdx, typename UpIdx>
515  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
516  const UpIdx& idx_up) const
517  {
518  static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
519  "wrong! inconsistent # of dimension");
520 
521  index_t tmp = idx_up[Number<0>{}];
522 
523  // normal division
524  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
525  idx_low(i) = tmp / this->low_lengths_scan_[i];
526  tmp -= idx_low[i] * this->low_lengths_scan_[i];
527  });
528 
529  idx_low(Number<NDimLow - 1>{}) = tmp;
530  }
531 
 // Lowers an upper index delta and advances idx_low in place (variant "1a", the one
 // selected by UpdateLowerIndex).
 // idx_diff_up is first decomposed into per-dimension digits (idx_diff_low_const) with
 // the same divisions as CalculateLowerIndex; C++ integer division truncates toward zero,
 // so a negative delta yields non-positive digits. A single carry/borrow is then
 // propagated from the last dimension down to dimension 1; dimension 0 just absorbs the
 // final carry/borrow and is never wrapped.
 //   Hack == 1: carry only  (presumably used when idx_diff_up >= 0 -- TODO confirm)
 //   Hack == 2: borrow only (presumably used when idx_diff_up <= 0 -- TODO confirm)
 //   otherwise: both carry and borrow are checked.
 // NOTE(review): the wrap logic appears to assume idx_low is already in range, so that each
 // digit wraps at most once -- confirm against callers.
532  template <typename LowIdxDiff,
533  typename UpIdxDiff,
534  typename LowIdx,
535  typename UpIdx,
536  index_t Hack>
537  __host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff& idx_diff_low,
538  const UpIdxDiff& idx_diff_up,
539  LowIdx& idx_low,
540  const UpIdx& /* idx_up_new */,
541  Number<Hack>) const
542  {
543  static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
544  LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
545  "wrong! inconsistent # of dimension");
546 
547  // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
548  // However,
549  // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const
550  // can be calculated at compile-time.
551  // 2) If idx_diff_up is not known at compile-time, but its value
552  // doesn't change during the whole kernel execution, then
553  // idx_diff_low_const also
554  // doesn't change during the whole kernel execution. Compiler generated
555  // ISA should
556  // only calculate idx_diff_low_const once and save it during the whole
557  // kernel execution
558  // If neither 1) nor 2) is satisfied, then the calculation will also be
559  // computed at
560  // run-time each time this function is called, and can be very expensive.
561  LowerIndex idx_diff_low_const;
562  LowerIndex idx_low_length_minus_idx_diff_low_const;
563  LowerIndex idx_low_length_plus_idx_diff_low_const;
564 
 // CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE is defined
 // outside this listing; both branches compute the same values.
565 #if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
566  index_t tmp = idx_diff_up[Number<0>{}];
567 
568  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
569  idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
570  tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
571  });
572 
573  idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
574 
575  static_for<0, NDimLow, 1>{}([&](auto i) {
576  idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];
577 
578  idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
579  });
580 #else
581  // Hack: this forces the result into an SGPR. The result must be thread invariant
 // (readfirstlane broadcasts lane 0's value to the whole wavefront).
582  index_t tmp = idx_diff_up[Number<0>{}];
583 
584  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
585  idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
586  tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
587  });
588 
589  idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
590 
591  static_for<0, NDimLow, 1>{}([&](auto i) {
592  idx_low_length_minus_idx_diff_low_const(i) =
593  __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);
594 
595  idx_low_length_plus_idx_diff_low_const(i) =
596  __builtin_amdgcn_readfirstlane(low_lengths_[i] + idx_diff_low_const[i]);
597  });
598 #endif
599 
600  if constexpr(Hack == 1)
601  {
602  // do carry check on each low dimension in reversed order
603  // do not need to check the first dimension
604  index_t carry = 0;
605 
606  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
607  index_t idx_low_tmp = idx_low[i] + carry;
608 
 // idx_low_tmp + digit >= length  <=>  idx_low_tmp >= length - digit
609  bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
610 
611  idx_diff_low(i) =
612  do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
613 
614  idx_diff_low(i) += carry;
615 
616  carry = do_carry ? 1 : 0;
617  });
618 
619  idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
620 
621  idx_low += idx_diff_low;
622  }
623  else if constexpr(Hack == 2)
624  {
625  // do borrow check on each low dimension in reversed order
626  // do not need to check the first dimension
627  index_t borrow = 0;
628 
629  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
630  index_t idx_low_tmp = idx_low[i] - borrow;
631 
 // idx_low_tmp + digit < 0  <=>  idx_low_tmp < -digit
632  bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
633 
634  idx_diff_low(i) =
635  do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];
636 
637  idx_diff_low(i) -= borrow;
638 
639  borrow = do_borrow ? 1 : 0;
640  });
641 
642  idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;
643 
644  idx_low += idx_diff_low;
645  }
646  else
647  {
648  // do carry and borrow check on each low dimension in reversed order
649  // do not need to check the first dimension
 // carry is +1 after a carry, -1 after a borrow, 0 otherwise.
650  index_t carry = 0;
651 
652  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
653  index_t idx_low_tmp = idx_low[i] + carry;
654 
655  bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
656  bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
657 
658  idx_diff_low(i) =
659  do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
660  idx_diff_low(i) =
661  do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];
662 
663  idx_diff_low(i) += carry;
664 
665  carry = do_carry ? 1 : 0;
666  carry = do_borrow ? -1 : carry;
667  });
668 
669  idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
670 
671  idx_low += idx_diff_low;
672  }
673  }
674 
 // Variant "1b" of UpdateLowerIndex_1a (not selected by UpdateLowerIndex). It differs
 // from 1a in the Hack == 2 borrow test, which compares the negated index against the
 // digit, and in the readfirstlane branch, which does not wrap the "plus" array in
 // __builtin_amdgcn_readfirstlane.
675  template <typename LowIdxDiff,
676  typename UpIdxDiff,
677  typename LowIdx,
678  typename UpIdx,
679  index_t Hack>
680  __host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff& idx_diff_low,
681  const UpIdxDiff& idx_diff_up,
682  LowIdx& idx_low,
683  const UpIdx& /* idx_up_new */,
684  Number<Hack>) const
685  {
686  static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
687  LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
688  "wrong! inconsistent # of dimension");
689 
690  // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
691  // However,
692  // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const
693  // can be calculated at compile-time.
694  // 2) If idx_diff_up is not known at compile-time, but its value
695  // doesn't change during the whole kernel execution, then
696  // idx_diff_low_const also
697  // doesn't change during the whole kernel execution. Compiler generated
698  // ISA should
699  // only calculate idx_diff_low_const once and save it during the whole
700  // kernel execution
701  // If neither 1) nor 2) is satisfied, then the calculation will also be
702  // computed at
703  // run-time each time this function is called, and can be very expensive.
704  LowerIndex idx_diff_low_const;
705  LowerIndex idx_low_length_minus_idx_diff_low_const;
706  LowerIndex idx_low_length_plus_idx_diff_low_const;
707 
708 #if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
709  index_t tmp = idx_diff_up[Number<0>{}];
710 
711  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
712  idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
713  tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
714  });
715 
716  idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
717 
718  static_for<0, NDimLow, 1>{}([&](auto i) {
719  idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];
720 
721  idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
722  });
723 #else
724  // Hack: this forces the result into an SGPR. The result must be thread invariant
725  index_t tmp = idx_diff_up[Number<0>{}];
726 
727  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
728  idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
729  tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
730  });
731 
732  idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
733 
734  static_for<0, NDimLow, 1>{}([&](auto i) {
735  idx_low_length_minus_idx_diff_low_const(i) =
736  __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);
737 
738  idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
739  });
740 #endif
741 
742  if constexpr(Hack == 1)
743  {
744  // do carry check on each low dimension in reversed order
745  // do not need to check the first dimension
746  index_t carry = 0;
747 
748  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
749  index_t idx_low_tmp = idx_low[i] + carry;
750 
751  bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
752 
753  idx_diff_low(i) =
754  do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
755 
756  idx_diff_low(i) += carry;
757 
758  carry = do_carry ? 1 : 0;
759  });
760 
761  idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
762 
763  idx_low += idx_diff_low;
764  }
765  else if constexpr(Hack == 2)
766  {
767  // do borrow check on each low dimension in reversed order
768  // do not need to check the first dimension
769  index_t borrow = 0;
770 
771  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
772  index_t negative_idx_low_tmp = borrow - idx_low[i];
773 
 // Equivalent to (idx_low[i] - borrow) < -idx_diff_low_const[i].
774  bool do_borrow = negative_idx_low_tmp > idx_diff_low_const[i];
775 
776  idx_diff_low(i) =
777  do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];
778 
779  idx_diff_low(i) -= borrow;
780 
781  borrow = do_borrow ? 1 : 0;
782  });
783 
784  idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;
785 
786  idx_low += idx_diff_low;
787  }
788  else
789  {
790  // do carry and borrow check on each low dimension in reversed order
791  // do not need to check the first dimension
792  index_t carry = 0;
793 
794  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
795  index_t idx_low_tmp = idx_low[i] + carry;
796 
797  bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
798  bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
799 
800  idx_diff_low(i) =
801  do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
802  idx_diff_low(i) =
803  do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];
804 
805  idx_diff_low(i) += carry;
806 
807  carry = do_carry ? 1 : 0;
808  carry = do_borrow ? -1 : carry;
809  });
810 
811  idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
812 
813  idx_low += idx_diff_low;
814  }
815  }
816 
 // Variant "2" (not selected by UpdateLowerIndex): adds each digit plus the incoming
 // carry/borrow, tests the resulting index directly against [0, low_lengths_[i]), wraps
 // the delta, and updates idx_low one dimension at a time.
 // NOTE(review): only Hack == 1 (carry) and Hack == 2 (borrow) are implemented; for any
 // other Hack this function leaves idx_diff_low and idx_low unmodified.
817  template <typename LowIdxDiff,
818  typename UpIdxDiff,
819  typename LowIdx,
820  typename UpIdx,
821  index_t Hack>
822  __host__ __device__ void UpdateLowerIndex_2(LowIdxDiff& idx_diff_low,
823  const UpIdxDiff& idx_diff_up,
824  LowIdx& idx_low,
825  const UpIdx& /* idx_up_new */,
826  Number<Hack>) const
827  {
828  static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
829  LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
830  "wrong! inconsistent # of dimension");
831 
832  // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
833  // However,
834  // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const
835  // can be calculated at compile-time.
836  // 2) If idx_diff_up is not known at compile-time, but its value
837  // doesn't change during the whole kernel execution, then
838  // idx_diff_low_const also
839  // doesn't change during the whole kernel execution. Compiler generated
840  // ISA should
841  // only calculate idx_diff_low_const once and save it during the whole
842  // kernel execution
843  // If neither 1) nor 2) is satisfied, then the calculation will also be
844  // computed at run-time each time this function is called, and can be
845  // very expensive.
846  LowerIndex idx_diff_low_const;
847 
848 #if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
849  index_t tmp = idx_diff_up[Number<0>{}];
850 
851  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
852  idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
853  tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
854  });
855 
856  idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
857 #else
858  // Hack: this forces the result into an SGPR. The result must be thread invariant
859  index_t tmp = idx_diff_up[Number<0>{}];
860 
861  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
862  idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
863  tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
864  });
865 
866  idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
867 #endif
868 
869  if constexpr(Hack == 1)
870  {
871  // do carry check on each low dimension in reversed order
872  // do not need to check the first dimension
 // do_carry is a bool; it promotes to 0/1 when added to an index.
873  bool do_carry = 0;
874 
875  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
876  idx_diff_low(i) = idx_diff_low_const[i] + do_carry;
877 
878  index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];
879 
880  do_carry = idx_low_tmp >= low_lengths_[i];
881 
882 #if 0
883  // TODO: use exec-mask inline asm, which use 1 VALU
884  if(do_carry)
885  {
886  idx_diff_low(i) -= low_lengths_[i];
887  }
888 #elif 1
889  // this use 2 VALU
890  idx_diff_low(i) = do_carry ? idx_diff_low[i] - low_lengths_[i] : idx_diff_low[i];
891 #elif 1
892  // this use 2 VALU
893  index_t idx_diff_low_tmp = idx_diff_low[i] - low_lengths_[i];
894  idx_diff_low(i) = do_carry ? idx_diff_low_tmp : idx_diff_low[i];
895 #endif
896 
897  idx_low(i) += idx_diff_low[i];
898  });
899 
900  constexpr auto I0 = Number<0>{};
901 
902  idx_diff_low(I0) = idx_diff_low_const[I0] + do_carry;
903 
904  idx_low(I0) += idx_diff_low[I0];
905  }
906  else if constexpr(Hack == 2)
907  {
908  // do borrow check on each low dimension in reversed order
909  // do not need to check the first dimension
910  bool do_borrow = 0;
911 
912  static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
913  idx_diff_low(i) = idx_diff_low_const[i] - do_borrow;
914 
915  index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];
916 
917  do_borrow = idx_low_tmp < 0;
918 
919 #if 0
920  // TODO: use exec-mask inline asm
921  if(do_borrow)
922  {
923  idx_diff_low(i) += low_lengths_[i];
924  }
925 #elif 1
926  idx_diff_low(i) = do_borrow ? idx_diff_low[i] + low_lengths_[i] : idx_diff_low[i];
927 #elif 1
928  index_t idx_diff_low_tmp = idx_diff_low[i] + low_lengths_[i];
929  idx_diff_low(i) = do_borrow ? idx_diff_low_tmp : idx_diff_low[i];
930 #endif
931 
932  idx_low(i) += idx_diff_low[i];
933  });
934 
935  constexpr auto I0 = Number<0>{};
936 
937  idx_diff_low(I0) = idx_diff_low_const[I0] - do_borrow;
938 
939  idx_low(I0) += idx_diff_low[I0];
940  }
941  else
942  {
943  // not implemented
944  }
945  }
946 
 // Public delta-lowering entry point. The preprocessor switch below hard-selects
 // UpdateLowerIndex_1a; 1b and 2 are only used if the switch is edited.
947  template <typename LowIdxDiff,
948  typename UpIdxDiff,
949  typename LowIdx,
950  typename UpIdx,
951  index_t Hack>
952  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
953  const UpIdxDiff& idx_diff_up,
954  LowIdx& idx_low,
955  const UpIdx& idx_up_new,
956  Number<Hack>) const
957  {
958 #if 1
959  UpdateLowerIndex_1a(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
960 #elif 0
961  UpdateLowerIndex_1b(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
962 #else
963  UpdateLowerIndex_2(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
964 #endif
965  }
966 
 // Merging is not linear: the lower delta depends on the current lower index (carries).
967  __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
968 
969  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
970  {
971  return true;
972  }
973 
 // NOTE(review): the body of this function (original lines 976-978) is missing from
 // this listing.
974  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
975  {
979  }
980 
981  template <typename UpIdx>
982  __host__ __device__ static constexpr bool
983  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
984  {
985  return true;
986  }
987 
 // Debug print of low_lengths_, low_lengths_scan_ and up_lengths_ via printf and
 // print_multi_index (no trailing newline).
988  __host__ __device__ void Print() const
989  {
990  printf("{");
991  printf("Merge_v1_carry_check, ");
992  printf("low_lengths_ ");
993  print_multi_index(low_lengths_);
994  printf("low_lengths_scan_ ");
995  print_multi_index(low_lengths_scan_);
996  printf("up_lengths_ ");
997  print_multi_index(up_lengths_);
998  printf("}");
999  }
1000 };
1001 
// NOTE(review): the struct declaration (original line 1003) is missing from this listing.
//
// Functor: for a compile-time index i, returns
// MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]). It reads a default-constructed
// LowLengths, so it is presumably meaningful only when the lengths are compile-time
// Numbers -- TODO confirm.
1002 template <typename LowLengths>
1004 {
1005  template <index_t I>
1006  __host__ __device__ constexpr auto operator()(Number<I> i) const
1007  {
1008  return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]);
1009  }
1010 };
1011 
// NOTE(review): the struct declaration (original line 1013) is missing from this listing.
//
// Functor: for a compile-time index i, returns
// MagicDivision::CalculateMagicShift(LowLengths{}[i]). It reads a default-constructed
// LowLengths, so it is presumably meaningful only when the lengths are compile-time
// Numbers -- TODO confirm.
1012 template <typename LowLengths>
1014 {
1015  template <index_t I>
1016  __host__ __device__ constexpr auto operator()(Number<I> i) const
1017  {
1018  return MagicDivision::CalculateMagicShift(LowLengths{}[i]);
1019  }
1020 };
1021 
1022 // Implementation of the "Merge" transformation primitive that uses magic-number division to lower
1023 // both the multi-index and the delta of the multi-index.
1024 // Caution:
1025 // 1. The magic-number division implementation used here produces a correct result only if the
1026 // dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
1027 // 2. Magic-number division for an int32_t dividend has not been implemented; an int32_t
1028 // dividend is bit-wise reinterpreted as uint32_t and the uint32_t magic-number division
1029 // implementation is used instead.
1030 // 3. For the Merge primitive, the upper index is the dividend.
1031 // 4. When the upper index is uint32_t, its value needs to be within the 31-bit range.
1032 // 5. When the upper index is int32_t (i.e. when index_t is int32_t), its value needs to be
1033 // non-negative.
1034 template <typename LowLengths>
1036 {
1037  static constexpr index_t NDimLow = LowLengths::Size();
1038 
1041 
1042  using UpLengths =
1043  decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
1044 
1047  Number<NDimLow>{}));
1048 
1051  Number<NDimLow>{}));
1052 
1053  LowLengths low_lengths_;
1057 
__host__ __device__ constexpr Merge_v2_magic_division() = default;

// Build the merge from its lower lengths:
//  - one magic-division (multiplier, shift) pair per lower length, used when lowering;
//  - a single upper length equal to the product of all lower lengths.
__host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths)
    : low_lengths_{low_lengths},
      low_lengths_magic_divisor_multiplier_{generate_tuple(
          [&](auto idim) { return MagicDivision::CalculateMagicMultiplier(low_lengths[idim]); },
          Number<NDimLow>{})},
      low_lengths_magic_divisor_shift_{generate_tuple(
          [&](auto idim) { return MagicDivision::CalculateMagicShift(low_lengths[idim]); },
          Number<NDimLow>{})},
      up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
{
    static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
1072 
// NDimLow lower dimensions are merged ...
__host__ __device__ static constexpr index_t GetNumOfLowerDimension()
{
    return NDimLow;
}

// ... into exactly one upper dimension.
__host__ __device__ static constexpr index_t GetNumOfUpperDimension()
{
    return 1;
}

// One-element tuple with the merged (product) length.
__host__ __device__ constexpr const auto& GetUpperLengths() const
{
    return up_lengths_;
}
1078 
// Split the single upper index into NDimLow lower indices. Dimensions are peeled from the last
// one (NDimLow-1) down to 1: each step divides by that dimension's length with magic-number
// division, stores the remainder as the lower index, and carries the quotient on.
// Whatever is left at the end becomes lower index 0.
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                       const UpIdx& idx_up) const
{
    static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");

    index_t remaining = idx_up[Number<0>{}];

    static_for<NDimLow - 1, 0, -1>{}([&, this](auto idim) {
        const index_t quotient =
            MagicDivision::DoMagicDivision(remaining,
                                           this->low_lengths_magic_divisor_multiplier_[idim],
                                           this->low_lengths_magic_divisor_shift_[idim]);

        idx_low(idim) = remaining - quotient * this->low_lengths_[idim];
        remaining     = quotient;
    });

    idx_low(Number<0>{}) = remaining;
}
1099 
// Recompute the lower index from the new upper index and record how far each lower dimension
// moved. The upper-index delta argument is ignored: every lower index is rebuilt from
// idx_up_new with the same last-to-first magic-division scheme as CalculateLowerIndex.
template <typename LowIdxDiff,
          typename UpIdxDiff,
          typename LowIdx,
          typename UpIdx,
          index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                          const UpIdxDiff&,
                                          LowIdx& idx_low,
                                          const UpIdx& idx_up_new,
                                          Number<Hack>) const
{
    static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                      LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");

    constexpr auto I0 = Number<0>{};

    index_t remaining = idx_up_new[I0];

    static_for<NDimLow - 1, 0, -1>{}([&, this](auto idim) {
        const index_t quotient =
            MagicDivision::DoMagicDivision(remaining,
                                           this->low_lengths_magic_divisor_multiplier_[idim],
                                           this->low_lengths_magic_divisor_shift_[idim]);

        const index_t previous = idx_low[idim];

        idx_low(idim) = remaining - quotient * this->low_lengths_[idim];
        remaining     = quotient;

        idx_diff_low(idim) = idx_low[idim] - previous;
    });

    // Outermost dimension: the delta is taken against the old value before overwriting it.
    idx_diff_low(I0) = remaining - idx_low[I0];
    idx_low(I0)      = remaining;
}
1135 
// Merge is not linear; its lower indices are recomputed by division.
__host__ __device__ static constexpr bool IsLinearTransform()
{
    return false;
}

// Any in-range upper index lands on an in-range lower index for this transform.
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; }
1142 
1143  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1144  {
1149  }
1150 
// Per-index validity check; this transform accepts every upper index unconditionally.
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
    return true;
}

// Debug dump: "{Merge_v2_magic_division, " followed by each member tuple, then "}".
__host__ __device__ void Print() const
{
    printf("{");
    printf("Merge_v2_magic_division, ");

    printf("low_lengths_ ");
    print_multi_index(low_lengths_);

    printf("low_lengths_magic_divisor_multiplier_ ");
    print_multi_index(low_lengths_magic_divisor_multiplier_);

    printf("low_lengths_magic_divisor_shift_ ");
    print_multi_index(low_lengths_magic_divisor_shift_);

    printf("up_lengths_ ");
    print_multi_index(up_lengths_);

    printf("}");
}
1172 };
1173 
1174 // Implementation of the "Merge" transformation primitive that uses magic-number division to lower
1175 // both the multi-index and the delta of the multi-index.
1176 // Caution:
1177 // 1. The magic-number division implementation used here produces a correct result only if the
1178 // dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
1179 // 2. Magic-number division for an int32_t dividend has not been implemented; an int32_t
1180 // dividend is bit-wise reinterpreted as uint32_t and the uint32_t magic-number division
1181 // implementation is used instead.
1182 // 3. For the Merge primitive, the upper index is the dividend.
1183 // 4. When the upper index is uint32_t, its value needs to be within the 31-bit range.
1184 // 5. When the upper index is int32_t (i.e. when index_t is int32_t), its value needs to be
1185 // non-negative.
1186 template <typename LowLengths>
1188 {
1189  static constexpr index_t NDimLow = LowLengths::Size();
1190 
1193 
1195  decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
1196 
1197  using UpLengths =
1198  decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
1199 
1202  Number<NDimLow>{}));
1203 
1206  Number<NDimLow>{}));
1207 
1208  LowLengths low_lengths_;
1213 
__host__ __device__ constexpr Merge_v2r2_magic_division() = default;

// Build the merge from its lower lengths:
//  - low_lengths_scan_[i] = product of lower lengths i+1 .. NDimLow-1 (reverse exclusive scan);
//  - one magic-division (multiplier, shift) pair per scan entry (read from the member
//    low_lengths_scan_, which is initialized just before them);
//  - a single upper length equal to the product of all lower lengths.
__host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths)
    : low_lengths_{low_lengths},
      low_lengths_scan_{
          container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
      low_lengths_scan_magic_divisor_multiplier_{generate_tuple(
          [&](auto idim) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[idim]); },
          Number<NDimLow>{})},
      low_lengths_scan_magic_divisor_shift_{generate_tuple(
          [&](auto idim) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[idim]); },
          Number<NDimLow>{})},
      up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
{
    static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
1230 
// NDimLow lower dimensions are merged ...
__host__ __device__ static constexpr index_t GetNumOfLowerDimension()
{
    return NDimLow;
}

// ... into exactly one upper dimension.
__host__ __device__ static constexpr index_t GetNumOfUpperDimension()
{
    return 1;
}

// One-element tuple with the merged (product) length.
__host__ __device__ constexpr const auto& GetUpperLengths() const
{
    return up_lengths_;
}
1236 
// Split the single upper index into NDimLow lower indices, outermost first: lower index i is the
// quotient of the running value by low_lengths_scan_[i] (via magic-number division), and the
// quotient times the scan entry is subtracted before moving on. The final remainder is the
// innermost lower index.
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                       const UpIdx& idx_up) const
{
    static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");

    index_t remaining = idx_up[Number<0>{}];

    static_for<0, NDimLow - 1, 1>{}([&, this](auto idim) {
        const index_t quotient =
            MagicDivision::DoMagicDivision(remaining,
                                           this->low_lengths_scan_magic_divisor_multiplier_[idim],
                                           this->low_lengths_scan_magic_divisor_shift_[idim]);

        idx_low(idim) = quotient;
        remaining -= quotient * this->low_lengths_scan_[idim];
    });

    idx_low(Number<NDimLow - 1>{}) = remaining;
}
1257 
// Recompute the lower index from the new upper index and record how far each lower dimension
// moved. The upper-index delta argument is ignored; lower indices are rebuilt outermost-first
// exactly as in CalculateLowerIndex.
template <typename LowIdxDiff,
          typename UpIdxDiff,
          typename LowIdx,
          typename UpIdx,
          index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                          const UpIdxDiff&,
                                          LowIdx& idx_low,
                                          const UpIdx& idx_up_new,
                                          Number<Hack>) const
{
    static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                      LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");

    constexpr auto ILast = Number<NDimLow - 1>{};

    index_t remaining = idx_up_new[Number<0>{}];

    static_for<0, NDimLow - 1, 1>{}([&, this](auto idim) {
        const index_t previous = idx_low[idim];

        const index_t quotient =
            MagicDivision::DoMagicDivision(remaining,
                                           this->low_lengths_scan_magic_divisor_multiplier_[idim],
                                           this->low_lengths_scan_magic_divisor_shift_[idim]);
        idx_low(idim) = quotient;

        idx_diff_low(idim) = quotient - previous;

        remaining -= quotient * this->low_lengths_scan_[idim];
    });

    // Innermost dimension: delta against the old value, then overwrite.
    idx_diff_low(ILast) = remaining - idx_low[ILast];
    idx_low(ILast)      = remaining;
}
1292 
// Merge is not linear; its lower indices are recomputed by division.
__host__ __device__ static constexpr bool IsLinearTransform()
{
    return false;
}

// Any in-range upper index lands on an in-range lower index for this transform.
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; }
1299 
1300  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1301  {
1306  }
1307 
// Per-index validity check; this transform accepts every upper index unconditionally.
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
    return true;
}

// Debug dump: "{Merge_v2r2_magic_division, " followed by each member tuple, then "}".
__host__ __device__ void Print() const
{
    printf("{");
    printf("Merge_v2r2_magic_division, ");

    printf("low_lengths_ ");
    print_multi_index(low_lengths_);

    printf("low_lengths_scan ");
    print_multi_index(low_lengths_scan_);

    printf("low_lengths_scan_magic_divisor_multiplier_ ");
    print_multi_index(low_lengths_scan_magic_divisor_multiplier_);

    printf("low_lengths_scan_magic_divisor_shift_ ");
    print_multi_index(low_lengths_scan_magic_divisor_shift_);

    printf("up_lengths_ ");
    print_multi_index(up_lengths_);

    printf("}");
}
1331 };
1332 
1333 // Implementation of the "Merge" transformation primitive that uses division and modulo. It is
1334 // intended for low_lengths that are known at compile time and are powers of 2; otherwise
1335 // performance will be very poor.
1336 template <typename LowLengths>
1338 {
1339  static constexpr index_t NDimLow = LowLengths::Size();
1340 
1343 
1345  decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
1346 
1347  using UpLengths =
1348  decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
1349 
1350  LowLengths low_lengths_;
1353 
__host__ __device__ constexpr Merge_v3_division_mod() = default;

// Build the merge from its lower lengths: low_lengths_scan_[i] is the product of lower lengths
// i+1 .. NDimLow-1 (reverse exclusive scan), and the single upper length is the product of all
// lower lengths.
__host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths)
    : low_lengths_{low_lengths},
      low_lengths_scan_{
          container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
      up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
{
    static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
1364 
// NDimLow lower dimensions are merged ...
__host__ __device__ static constexpr index_t GetNumOfLowerDimension()
{
    return NDimLow;
}

// ... into exactly one upper dimension.
__host__ __device__ static constexpr index_t GetNumOfUpperDimension()
{
    return 1;
}

// One-element tuple with the merged (product) length.
__host__ __device__ constexpr const auto& GetUpperLengths() const
{
    return up_lengths_;
}
1370 
// Split the single upper index into NDimLow lower indices, outermost first, with plain integer
// division and modulo by low_lengths_scan_[i]. The final remainder is the innermost index.
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                       const UpIdx& idx_up) const
{
    static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");

    index_t remaining = idx_up[Number<0>{}];

    static_for<0, NDimLow - 1, 1>{}([&](auto idim) {
        idx_low(idim) = remaining / this->low_lengths_scan_[idim];
        remaining %= this->low_lengths_scan_[idim];
    });

    idx_low(Number<NDimLow - 1>{}) = remaining;
}
1388 
// Recompute the lower index from the new upper index with division/modulo and record how far
// each lower dimension moved. The upper-index delta argument is ignored.
template <typename LowIdxDiff,
          typename UpIdxDiff,
          typename LowIdx,
          typename UpIdx,
          index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                          const UpIdxDiff&,
                                          LowIdx& idx_low,
                                          const UpIdx& idx_up_new,
                                          Number<Hack>) const
{
    static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                      LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");

    constexpr auto I0   = Number<0>{};
    constexpr auto INm1 = Number<NDimLow - 1>{};

    index_t remaining = idx_up_new[I0];

    static_for<0, NDimLow - 1, 1>{}([&](auto idim) {
        const index_t previous = idx_low[idim];

        idx_low(idim)      = remaining / this->low_lengths_scan_[idim];
        idx_diff_low(idim) = idx_low[idim] - previous;

        remaining %= this->low_lengths_scan_[idim];
    });

    const index_t previous_last = idx_low[INm1];

    idx_low(INm1)      = remaining;
    idx_diff_low(INm1) = remaining - previous_last;
}
1420 
// Merge is not linear; its lower indices are recomputed by division.
__host__ __device__ static constexpr bool IsLinearTransform()
{
    return false;
}

// Any in-range upper index lands on an in-range lower index for this transform.
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; }
1427 
1428  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1429  {
1433  }
1434 
// Per-index validity check; this transform accepts every upper index unconditionally.
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
    return true;
}

// Debug dump: "{Merge_v3_division_mod, " followed by each member tuple, then "}".
// The tag now matches the struct name; it previously printed "Merge_v3_direct_division_mod",
// a name that does not exist in this file.
__host__ __device__ void Print() const
{
    printf("{");
    printf("Merge_v3_division_mod, ");
    printf("low_lengths_ ");
    print_multi_index(low_lengths_);
    printf("low_lengths_scan_ ");
    print_multi_index(low_lengths_scan_);
    printf("up_lengths_ ");
    print_multi_index(up_lengths_);
    printf("}");
}
1454 };
1455 
1456 template <typename UpLengths, bool Use24BitIntegerCalculation>
1457 struct UnMerge
1458 {
1459  static constexpr index_t NDimUp = UpLengths::Size();
1460 
1463 
1465  decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, Number<1>{}));
1466 
1467  UpLengths up_lengths_;
1469 
__host__ __device__ constexpr UnMerge() = default;

// Store the upper lengths and their reverse exclusive product scan:
// up_lengths_scan_[i] is the product of upper lengths i+1 .. NDimUp-1, i.e. the stride of
// upper dimension i inside the single lower dimension.
__host__ __device__ constexpr UnMerge(const UpLengths& up_lengths)
    : up_lengths_{up_lengths},
      up_lengths_scan_{
          container_reverse_exclusive_scan(up_lengths, math::multiplies{}, Number<1>{})}
{
}
1478 
// A single lower dimension ...
__host__ __device__ static constexpr index_t GetNumOfLowerDimension()
{
    return 1;
}

// ... is split into NDimUp upper dimensions.
__host__ __device__ static constexpr index_t GetNumOfUpperDimension()
{
    return NDimUp;
}

// The upper lengths given at construction.
__host__ __device__ constexpr const auto& GetUpperLengths() const
{
    return up_lengths_;
}
1484 
// Linearize the NDimUp upper indices into the single lower index:
//   low = idx_up[NDimUp-1] + sum_{i < NDimUp-1} idx_up[i] * up_lengths_scan_[i]
// With Use24BitIntegerCalculation, every operand is masked to its low 24 bits before each
// multiply-add (presumably so the compiler can use 24-bit integer multiplies; confirm).
// The dimension check matches the one every other transform in this file performs.
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                       const UpIdx& idx_up) const
{
    static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
                  "wrong! inconsistent # of dimension");

    constexpr auto I0 = Number<0>{};

    // The last upper dimension has stride 1.
    idx_low(I0) = idx_up[Number<NDimUp - 1>{}];

    if constexpr(!Use24BitIntegerCalculation)
    {
        static_for<0, NDimUp - 1, 1>{}(
            [&](auto i) { idx_low(I0) += idx_up[i] * up_lengths_scan_[i]; });
    }
    else
    {
        static_for<0, NDimUp - 1, 1>{}([&](auto i) {
            idx_low(I0) = (0x00ffffff & idx_low[I0]) +
                          (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]);
        });
    }
}
1507 
// UnMerge is linear: the lower-index delta is the linearized upper-index delta, which is then
// added to idx_low. The new upper index itself is not needed.
// NOTE(review): in the Use24BitIntegerCalculation mode, CalculateLowerIndex masks each operand
// with 0x00ffffff. That does not preserve a negative delta; confirm callers only step forward
// when that mode is enabled.
template <typename LowIdxDiff,
          typename UpIdxDiff,
          typename LowIdx,
          typename UpIdx,
          index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                          const UpIdxDiff& idx_diff_up,
                                          LowIdx& idx_low,
                                          const UpIdx&,
                                          Number<Hack>) const
{
    static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&
                      LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
                  "wrong! inconsistent # of dimension");

    CalculateLowerIndex(idx_diff_low, idx_diff_up);

    idx_low += idx_diff_low;
}
1523 
// UnMerge is linear: the lower delta depends only on the upper delta.
__host__ __device__ static constexpr bool IsLinearTransform()
{
    return true;
}

// Any in-range upper index lands on an in-range lower index for this transform.
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; }

// Per-index validity check; always true here.
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
    return true;
}
1537 
1538  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1539  {
1542  }
1543 
// Debug dump: "{UnMerge, " followed by the upper lengths and their scan, then "}".
__host__ __device__ void Print() const
{
    printf("{");
    printf("UnMerge, ");

    printf("up_lengths_");
    print_multi_index(up_lengths_);

    printf("up_lengths_scan_");
    print_multi_index(up_lengths_scan_);

    printf("}");
}
1554 };
1555 
1565 {
1566  static constexpr auto I0 = Number<0>{};
1567  static constexpr auto I1 = Number<1>{};
1568  static constexpr auto I2 = Number<2>{};
1569  static constexpr auto I3 = Number<3>{};
1570 
1571  using LowerIndex = MultiIndex<4>; // N, Ho, Wo, K
1572  using UpperIndex = MultiIndex<3>; // K0, M, K1
1573 
1574  index_t N_, Ho_, Wo_, K_;
1576  index_t HTilde_, WTilde_;
1577  index_t WTildeSlice_, TildeSlice_;
1578  index_t IHTildeSliceBegin_, IWTildeSliceBegin_;
1579  index_t HRatio_, WRatio_;
1581  index_t MPad_, KPad_;
1583 
1585  low_lengths_magic_divisor_multiplier_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
1587  low_lengths_magic_divisor_shift_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
1588 
__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform() = default;

// Store the geometry and precompute magic-division constants for the four divisors used when
// lowering, in this order in both arrays: XDotSlice_K_, K_, TildeSlice_, WTildeSlice_.
// The three upper lengths are (K0, MPadded, K1).
__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N,
                                                                  index_t Ho,
                                                                  index_t Wo,
                                                                  index_t K,
                                                                  index_t XDot,
                                                                  index_t HTilde,
                                                                  index_t WTilde,
                                                                  index_t WTildeSlice,
                                                                  index_t HWTildeSlice,
                                                                  index_t IHTildeSliceBegin,
                                                                  index_t IWTildeSliceBegin,
                                                                  index_t HRatio,
                                                                  index_t WRatio,
                                                                  index_t XDotSlice_K,
                                                                  index_t K0,
                                                                  index_t MPadded,
                                                                  index_t K1,
                                                                  index_t MPad,
                                                                  index_t KPad)
    : N_{N},
      Ho_{Ho},
      Wo_{Wo},
      K_{K},
      XDot_{XDot},
      HTilde_{HTilde},
      WTilde_{WTilde},
      WTildeSlice_{WTildeSlice},
      TildeSlice_{HWTildeSlice},
      IHTildeSliceBegin_{IHTildeSliceBegin},
      IWTildeSliceBegin_{IWTildeSliceBegin},
      HRatio_{HRatio},
      WRatio_{WRatio},
      XDotSlice_K_{XDotSlice_K},
      MPad_{MPad},
      KPad_{KPad},
      up_lengths_{make_tuple(K0, MPadded, K1)},
      // The magic constants read the members initialized above.
      low_lengths_magic_divisor_multiplier_{MagicDivision::CalculateMagicMultiplier(XDotSlice_K_),
                                            MagicDivision::CalculateMagicMultiplier(K_),
                                            MagicDivision::CalculateMagicMultiplier(TildeSlice_),
                                            MagicDivision::CalculateMagicMultiplier(WTildeSlice_)},
      low_lengths_magic_divisor_shift_{MagicDivision::CalculateMagicShift(XDotSlice_K_),
                                       MagicDivision::CalculateMagicShift(K_),
                                       MagicDivision::CalculateMagicShift(TildeSlice_),
                                       MagicDivision::CalculateMagicShift(WTildeSlice_)}
{
}
1638 
// Four lower dimensions (N, Ho, Wo, K) ...
__host__ __device__ static constexpr index_t GetNumOfLowerDimension()
{
    return 4;
}

// ... and three upper dimensions (K0, M, K1).
__host__ __device__ static constexpr index_t GetNumOfUpperDimension()
{
    return 3;
}

// Upper lengths (K0, MPadded, K1).
__host__ __device__ constexpr const auto& GetUpperLengths() const
{
    return up_lengths_;
}
1644 
// Contribution of the M upper index (idx_up[I1]) to the lower (N, Ho, Wo, K) index.
// M is split as n = M / TildeSlice_, h = (M % TildeSlice_) / WTildeSlice_, and w = the rest,
// using magic-number division. h and w are then offset by the slice begins.
// The K component is 0.
template <typename UpIdx>
__host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const
{
    const index_t m_idx = idx_up[I1];

    index_t n = MagicDivision::DoMagicDivision(m_idx,
                                               this->low_lengths_magic_divisor_multiplier_[I2],
                                               this->low_lengths_magic_divisor_shift_[I2]);

    const index_t hw = m_idx - n * TildeSlice_;

    index_t h = MagicDivision::DoMagicDivision(hw,
                                               this->low_lengths_magic_divisor_multiplier_[I3],
                                               this->low_lengths_magic_divisor_shift_[I3]);

    index_t w = hw - h * WTildeSlice_;

    // Slice: shift into the selected tilde window.
    h += IHTildeSliceBegin_;
    w += IWTildeSliceBegin_;

    return make_tuple(n, h, w, 0);
}
1666 
// Contribution of the K0/K1 upper indices to the lower (N, Ho, Wo, K) index.
// (K0, K1) are first combined into one GEMM-K index, which is split as
// y = k / XDotSlice_K_, x = (k % XDotSlice_K_) / K_, kk = the rest.
// y and x are then scaled by HRatio_ / WRatio_. The N component is 0.
template <typename UpIdx>
__host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const
{
    // UnMerge: k = K0_idx * K1 + K1_idx
    const index_t k_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];

    index_t y = MagicDivision::DoMagicDivision(k_idx,
                                               this->low_lengths_magic_divisor_multiplier_[I0],
                                               this->low_lengths_magic_divisor_shift_[I0]);

    const index_t xk = k_idx - y * XDotSlice_K_;

    index_t x = MagicDivision::DoMagicDivision(xk,
                                               this->low_lengths_magic_divisor_multiplier_[I1],
                                               this->low_lengths_magic_divisor_shift_[I1]);

    index_t kk = xk - x * K_;

    // Embed: scale the filter offsets by the stride/dilation ratios.
    y *= HRatio_;
    x *= WRatio_;

    return make_tuple(0, y, x, kk);
}
1692 
// Full lowering: sum of the M-derived and K-derived contributions.
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                       const UpIdx& idx_up) const
{
    idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
}

// Non-incremental update: recompute the lower index from the current upper index (the upper
// delta is ignored) and report the lower delta against the previous lower index.
template <typename LowIdxDiff,
          typename UpIdxDiff,
          typename LowIdx,
          typename UpIdx,
          index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                          const UpIdxDiff& /* idx_diff_up */,
                                          LowIdx& idx_low,
                                          const UpIdx& idx_up,
                                          Number<Hack>) const
{
    LowIdx previous = idx_low;

    CalculateLowerIndex(idx_low, idx_up);

    idx_diff_low = idx_low - previous;
}
1715 
// Not linear: the lower index is recomputed via division.
__host__ __device__ static constexpr bool IsLinearTransform()
{
    return false;
}

// Any in-range upper index lands on an in-range lower index for this transform.
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; }
1722 
// Padding check on the (K0, M, K1) upper index. The index is valid only if M lies below
// up_lengths_[I1] - MPad_ and the combined K index (K0_idx * K1 + K1_idx) lies below
// K0 * K1 - KPad_.
template <typename UpIdx>
__host__ __device__ constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
{
    const index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];

    // Read by value. idx_up is const, so the previous `index_t& M_idx = idx_up[...]` bound a
    // non-const reference to a const element, which is ill-formed once this template is
    // instantiated.
    const index_t M_idx = idx_up[I1];

    const bool m_valid = M_idx < up_lengths_[I1] - MPad_;
    const bool k_valid = K_idx < up_lengths_[I0] * up_lengths_[I2] - KPad_;

    return m_valid && k_valid;
}
1735 
// Runtime geometry: nothing is known at compile time.
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
    return false;
}

// Debug dump: "{ConvBwdDataImplicitGemmOutTransform, " + upper lengths + "}".
__host__ __device__ void Print() const
{
    printf("{");
    printf("ConvBwdDataImplicitGemmOutTransform, ");

    printf("up_lengths_");
    print_multi_index(up_lengths_);

    printf("}");
}
1746 };
1747 
1748 template <typename LowerIndex>
1749 struct Freeze
1750 {
1751  LowerIndex low_idx_;
1752 
__host__ __device__ constexpr Freeze() = default;

// Pin the single lower dimension to low_idx.
__host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}

// One lower dimension ...
__host__ __device__ static constexpr index_t GetNumOfLowerDimension()
{
    return 1;
}

// ... and no upper dimension.
__host__ __device__ static constexpr index_t GetNumOfUpperDimension()
{
    return 0;
}

// No upper dimensions, so the lengths tuple is empty.
__host__ __device__ static constexpr auto GetUpperLengths()
{
    return Tuple<>{};
}
1762 
// The lower index is always the frozen value; there is no upper index to read.
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                       const UpIdx& /* idx_up */) const
{
    static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0,
                  "wrong! inconsistent # of dimension");

    constexpr auto I0 = Number<0>{};

    idx_low(I0) = low_idx_;
}
1772 
// A frozen dimension never moves: the lower delta is always 0 and idx_low is left untouched.
// The dimension check mirrors the one in CalculateLowerIndex and in the other transforms.
template <typename LowIdxDiff,
          typename UpIdxDiff,
          typename LowIdx,
          typename UpIdx,
          index_t Hack>
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                 const UpIdxDiff& /* idx_diff_up */,
                                                 LowIdx& /* idx_low */,
                                                 const UpIdx& /* idx_up_new */,
                                                 Number<Hack>)
{
    static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 0 && LowIdx::Size() == 1 &&
                      UpIdx::Size() == 0,
                  "wrong! inconsistent # of dimension");

    idx_diff_low(Number<0>{}) = 0;
}
1786 
// Linear (trivially: the lower delta is always zero).
__host__ __device__ static constexpr bool IsLinearTransform()
{
    return true;
}

// The frozen lower index is always considered valid.
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; }

// Per-index validity check; always true here.
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
    return true;
}
1800 
1801  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1802  {
1804  }
1805 
// Debug dump: "{Freeze, low_idx_ <value>}". The braces and ", " separator match the other
// transforms. Previously the output ran together as "Freezelow_idx_ <value>".
__host__ __device__ void Print() const
{
    printf("{");
    printf("Freeze, ");
    printf("low_idx_ %d", index_t{low_idx_});
    printf("}");
}
1811 };
1812 
1813 // Insert a dangling upper dimension without lower dimension
1814 template <typename UpperLength>
1815 struct Insert
1816 {
1817  using UpLengths = decltype(make_tuple(UpperLength{}));
1818 
1820 
__host__ __device__ constexpr Insert() = default;

// Create a dangling upper dimension of length up_length.
__host__ __device__ constexpr Insert(const UpperLength& up_length)
    : up_lengths_{make_tuple(up_length)}
{
}

// No lower dimension ...
__host__ __device__ static constexpr index_t GetNumOfLowerDimension()
{
    return 0;
}

// ... and one upper dimension.
__host__ __device__ static constexpr index_t GetNumOfUpperDimension()
{
    return 1;
}

// Returned by value (a copy of the one-element tuple).
__host__ __device__ constexpr auto GetUpperLengths() const
{
    return up_lengths_;
}
1833 
// There is no lower index to produce; only the dimension counts are checked.
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const
{
    static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");
}

// Moving along the inserted dimension never changes a lower index; only the dimension counts
// are checked.
template <typename LowIdxDiff,
          typename UpIdxDiff,
          typename LowIdx,
          typename UpIdx,
          index_t Hack>
__host__ __device__ static void
UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number<Hack>)
{
    static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 &&
                      UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");
}
1853 
// Linear (there is nothing to map).
__host__ __device__ static constexpr bool IsLinearTransform()
{
    return true;
}

// Any upper index is acceptable.
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; }

// Per-index validity check; always true here.
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
    return true;
}
1867 
1868  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1869  {
1871  }
1872 
// Debug dump: "{Insert, up_lengths_" + lengths + "}". This uses the same framing as the other
// transforms; previously it printed "Insert" with no separator or braces.
__host__ __device__ void Print() const
{
    printf("{");
    printf("Insert, ");
    printf("up_lengths_");
    print_multi_index(up_lengths_);
    printf("}");
}
1878 };
1879 
1880 template <typename VectorSize, typename UpLength>
1882 {
1885 
1886  using UpLengths = decltype(make_tuple(UpLength{}));
1887 
1889  VectorSize vector_size_;
1890 
__host__ __device__ constexpr Vectorize() = default;

// Map an upper index to lower index vector_size * upper index.
// Mem-initializers are listed in declaration order: up_lengths_ appears to be declared just
// before vector_size_ (TODO confirm against the member list). Members are always initialized
// in declaration order, so the previous ordering (vector_size_ first) only triggered -Wreorder;
// behaviour is unchanged.
__host__ __device__ constexpr Vectorize(const VectorSize& vector_size,
                                        const UpLength& up_length)
    : up_lengths_{make_tuple(up_length)}, vector_size_{vector_size}
{
}

// One lower dimension ...
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

// ... and one upper dimension.
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

// One-element tuple with the upper length.
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1904 
// lower = vector_size_ * upper
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                       const UpIdx& idx_up) const
{
    static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");

    constexpr auto I0 = Number<0>{};

    idx_low(I0) = vector_size_ * idx_up[I0];
}

// Linear update: the lower delta is vector_size_ times the upper delta; it is then added to
// idx_low.
template <typename LowIdxDiff,
          typename UpIdxDiff,
          typename LowIdx,
          typename UpIdx,
          index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                          const UpIdxDiff& idx_diff_up,
                                          LowIdx& idx_low,
                                          const UpIdx&,
                                          Number<Hack>) const
{
    static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                      UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");

    constexpr auto I0 = Number<0>{};

    idx_diff_low(I0) = vector_size_ * idx_diff_up[I0];
    idx_low += idx_diff_low;
}
1936 
// Linear: lower = vector_size_ * upper.
__host__ __device__ static constexpr bool IsLinearTransform()
{
    return true;
}

// Any in-range upper index lands on an in-range lower index for this transform.
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; }

// Per-index validity check; always true here.
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
    return true;
}
1950 
1951  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1952  {
1954  }
1955 
// Debug dump: "{Vectorize, up_lengths_" + lengths + "}".
__host__ __device__ void Print() const
{
    printf("{");
    printf("Vectorize, ");

    printf("up_lengths_");
    print_multi_index(up_lengths_);

    printf("}");
}
1964 };
1965 
1966 template <typename LowLength, typename SliceBegin, typename SliceEnd>
1967 struct Slice
1968 {
1971 
1972  using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));
1973 
1975  SliceBegin slice_begin_;
1976  SliceEnd slice_end_;
1977 
__host__ __device__ constexpr Slice() = default;

// Expose the half-open window [slice_begin, slice_end) of the lower dimension as the upper
// dimension. The lower length argument is accepted but not stored.
__host__ __device__ constexpr Slice(const LowLength&,
                                    const SliceBegin& slice_begin,
                                    const SliceEnd& slice_end)
    : up_lengths_{make_tuple(slice_end - slice_begin)},
      slice_begin_{slice_begin},
      slice_end_{slice_end}
{
}

// One lower dimension ...
__host__ __device__ static constexpr index_t GetNumOfLowerDimension()
{
    return 1;
}

// ... and one upper dimension.
__host__ __device__ static constexpr index_t GetNumOfUpperDimension()
{
    return 1;
}

// One-element tuple holding slice_end - slice_begin.
__host__ __device__ constexpr const auto& GetUpperLengths() const
{
    return up_lengths_;
}
1994 
// lower = upper + slice_begin_
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                       const UpIdx& idx_up) const
{
    static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");

    constexpr auto I0 = Number<0>{};

    idx_low(I0) = idx_up[I0] + slice_begin_;
}

// Linear update: the lower delta equals the upper delta; it is then added to idx_low.
template <typename LowIdxDiff,
          typename UpIdxDiff,
          typename LowIdx,
          typename UpIdx,
          index_t Hack>
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                 const UpIdxDiff& idx_diff_up,
                                                 LowIdx& idx_low,
                                                 const UpIdx&,
                                                 Number<Hack>)
{
    static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                      UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");

    constexpr auto I0 = Number<0>{};

    idx_diff_low(I0) = idx_diff_up[I0];
    idx_low += idx_diff_low;
}
2026 
// Linear: lower = upper + constant offset.
__host__ __device__ static constexpr bool IsLinearTransform()
{
    return true;
}

// Any in-range upper index lands on an in-range lower index for this transform.
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; }

// Per-index validity check; always true here.
template <typename UpIdx>
__host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const
{
    return true;
}
2039 
2040  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
2041  {
2045  }
2046 
2047  __host__ __device__ void Print() const
2048  {
2049  printf("{");
2050  printf("Slice, ");
2051  printf("up_lengths_");
2052  print_multi_index(up_lengths_);
2053  printf("slice_begin_ %d", index_t{slice_begin_});
2054  printf("slice_end %d", index_t{slice_end_});
2055  printf("}");
2056  }
2057 };
2058 
2059 /*
2060  * \brief lower_idx = upper_idx % modulus.
 * One upper dimension (length up_length, held in up_lengths_) maps onto one lower dimension.
 * Because of the wrap-around the map is not linear (IsLinearTransform() returns false).
2061  * TODO: Need an improved implementation since the modulo operation is expensive.
2062  */
2063 template <typename Modulus, typename UpLength>
2064 struct Modulo
2065 {
2068  using UpLengths = decltype(make_tuple(UpLength{}));
2069 
// Divisor applied to the upper index to obtain the lower index.
2070  Modulus modulus_;
2072 
2073  __host__ __device__ constexpr Modulo() = default;
2074 
2075  __host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length)
2076  : modulus_{modulus}, up_lengths_{make_tuple(up_length)}
2077  {
2078  }
2079 
2080  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
2081 
2082  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
2083 
2084  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
2085 
2086  template <typename LowIdx, typename UpIdx>
2087  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
2088  const UpIdx& idx_up) const
2089  {
2090  static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
2091  "wrong! inconsistent # of dimension");
2092 
2093  idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_;
2094  }
2095 
2096  template <typename LowIdxDiff,
2097  typename UpIdxDiff,
2098  typename LowIdx,
2099  typename UpIdx,
2100  index_t Hack>
2101  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
2102  const UpIdxDiff& idx_diff_up,
2103  LowIdx& idx_low,
2104  const UpIdx& up_idx,
2105  Number<Hack>) const
2106  {
2107  static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
2108  UpIdx::Size() == 1,
2109  "wrong! inconsistent # of dimension");
2110 
2111  constexpr auto I0 = Number<0>{};
2112 
2113  const auto idx_low_old = idx_low;
2114  idx_low(I0) = (up_idx(I0) + idx_diff_up(I0)) % modulus_;
2115  idx_diff_low(I0) = idx_low - idx_low_old;
2116  }
2117 
2118  __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
2119 
2120  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
2121  {
2122  return true;
2123  }
2124 
2125  template <typename UpIdx>
2126  __host__ __device__ static constexpr bool
2127  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
2128  {
2129  return true;
2130  }
2131 
2132  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
2133  {
2135  }
2136 
2137  __host__ __device__ void Print() const
2138  {
2139  printf("{");
2140  printf("Modulus, ");
2141  printf("up_lengths_");
2142  print_multi_index(up_lengths_);
2143  printf("}");
2144  }
2145 };
2146 
// Xor: 2-D -> 2-D index swizzle whose upper lengths equal its lower lengths (UpLengths = LowLengths).
//   lower[0] = upper[0]
//   lower[1] = upper[1] ^ upper[0]                      (ApplyModulo == false)
//   lower[1] = upper[1] ^ (upper[0] % up_lengths_[1])   (ApplyModulo == true)
2147 template <typename LowLengths, bool ApplyModulo>
2148 struct Xor
2149 {
2152 
2153  using UpLengths = LowLengths;
2154 
2156 
2157  __host__ __device__ constexpr Xor() : up_lengths_{} {}
2158 
2159  __host__ __device__ constexpr Xor(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
2160 
2161  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 2; }
2162 
2163  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 2; }
2164 
2165  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
2166 
2167  template <typename LowIdx, typename UpIdx>
2168  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
2169  const UpIdx& idx_up) const
2170  {
2171  static_assert(LowIdx::Size() == 2 && UpIdx::Size() == 2,
2172  "wrong! inconsistent # of dimension");
2173 
2174  idx_low(Number<0>{}) = idx_up[Number<0>{}];
2175 
2176  if constexpr(ApplyModulo)
2177  {
2178  idx_low(Number<1>{}) =
2179  idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]);
2180  }
2181  else
2182  {
2183  idx_low(Number<1>{}) = idx_up[Number<1>{}] ^ idx_up[Number<0>{}];
2184  }
2185  }
2186 
2187  template <typename LowIdxDiff,
2188  typename UpIdxDiff,
2189  typename LowIdx,
2190  typename UpIdx,
2191  index_t Hack>
2192  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
2193  const UpIdxDiff&,
2194  LowIdx& idx_low,
2195  const UpIdx& idx_up,
2196  Number<Hack>) const
2197  {
2198  static_assert(LowIdxDiff::Size() == 2 && UpIdxDiff::Size() == 2 && LowIdx::Size() == 2 &&
2199  UpIdx::Size() == 2,
2200  "wrong! inconsistent # of dimension");
2201 
2202  const auto idx_low_old = idx_low;
2203 
2204  CalculateLowerIndex(idx_low, idx_up);
2205 
2206  idx_diff_low = idx_low - idx_low_old;
2207  }
2208 
2209  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
2210  {
2211  return true;
2212  }
2213 
2214  template <typename UpIdx>
2215  __host__ __device__ static constexpr bool
2216  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
2217  {
2218  return true;
2219  }
2220 
2221  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
2222  {
2224  }
2225 
2226  __host__ __device__ void Print() const
2227  {
2228  printf("Xor{");
2229 
2230  //
2231  printf("up_lengths_: ");
2232  print(up_lengths_);
2233  printf(", ");
2234 
2235  printf("}");
2236  }
2237 };
2238 } // namespace ck
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
Definition: ck.hpp:268
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
__host__ constexpr __device__ auto container_reverse_exclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition: container_helper.hpp:213
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:299
__host__ constexpr __device__ auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition: container_helper.hpp:111
__host__ __device__ void print_multi_index(const Tuple< Xs... > &x)
Definition: statically_indexed_array_multi_index.hpp:147
Definition: array.hpp:14
Transformation struct for convolution backward data output indices to GEMM indices.
Definition: multi_index_transform.hpp:1565
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1694
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1716
index_t KPad_
Definition: multi_index_transform.hpp:1581
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1639
__host__ constexpr __device__ auto CalculateLowerIndexN(const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1646
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1718
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1738
index_t IHTildeSliceBegin_
Definition: multi_index_transform.hpp:1578
Tuple< index_t, index_t, index_t > up_lengths_
Definition: multi_index_transform.hpp:1582
index_t Ho_
Definition: multi_index_transform.hpp:1574
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1641
Tuple< index_t, index_t, index_t, index_t > low_lengths_magic_divisor_multiplier_
Definition: multi_index_transform.hpp:1585
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1736
Tuple< index_t, index_t, index_t, index_t > low_lengths_magic_divisor_shift_
Definition: multi_index_transform.hpp:1587
index_t XDot_
Definition: multi_index_transform.hpp:1575
index_t XDotSlice_K_
Definition: multi_index_transform.hpp:1580
__host__ constexpr __device__ auto CalculateLowerIndexK(const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1668
index_t HRatio_
Definition: multi_index_transform.hpp:1579
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1643
index_t TildeSlice_
Definition: multi_index_transform.hpp:1577
index_t HTilde_
Definition: multi_index_transform.hpp:1576
__host__ constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1725
__host__ constexpr __device__ ConvBwdDataImplicitGemmOutTransform(index_t N, index_t Ho, index_t Wo, index_t K, index_t XDot, index_t HTilde, index_t WTilde, index_t WTildeSlice, index_t HWTildeSlice, index_t IHTildeSliceBegin, index_t IWTildeSliceBegin, index_t HRatio, index_t WRatio, index_t XDotSlice_K, index_t K0, index_t MPadded, index_t K1, index_t MPad, index_t KPad)
Definition: multi_index_transform.hpp:1591
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up, Number< Hack >) const
Definition: multi_index_transform.hpp:1705
__host__ constexpr __device__ ConvBwdDataImplicitGemmOutTransform()=default
Definition: multi_index_transform.hpp:385
static constexpr index_t NDimUp
Definition: multi_index_transform.hpp:386
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:406
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:454
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:404
Coefficients coefficients_
Definition: multi_index_transform.hpp:392
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:445
__host__ constexpr __device__ Embed()=default
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:409
UpLengths up_lengths_
Definition: multi_index_transform.hpp:391
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:459
__host__ constexpr __device__ Embed(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: multi_index_transform.hpp:396
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:465
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:447
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition: multi_index_transform.hpp:427
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:402
Definition: multi_index_transform.hpp:1750
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1801
LowerIndex low_idx_
Definition: multi_index_transform.hpp:1751
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1757
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1787
__host__ static constexpr __device__ auto GetUpperLengths()
Definition: multi_index_transform.hpp:1761
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:1778
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1759
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1806
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1796
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &) const
Definition: multi_index_transform.hpp:1764
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1789
__host__ constexpr __device__ Freeze(const LowerIndex &low_idx)
Definition: multi_index_transform.hpp:1755
__host__ constexpr __device__ Freeze()=default
Definition: multi_index_transform.hpp:1816
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1830
__host__ constexpr __device__ Insert()=default
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &, const UpIdx &) const
Definition: multi_index_transform.hpp:1835
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1828
__host__ constexpr __device__ Insert(const UpperLength &up_length)
Definition: multi_index_transform.hpp:1823
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1868
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1854
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1873
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1819
decltype(make_tuple(UpperLength{})) UpLengths
Definition: multi_index_transform.hpp:1817
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1863
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &, const UpIdxDiff &, LowIdx &, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:1847
__host__ constexpr __device__ auto GetUpperLengths() const
Definition: multi_index_transform.hpp:1832
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1856
Definition: multi_index_transform.hpp:196
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:215
__host__ constexpr __device__ LeftPad(const LowLength &low_length, const LeftPadLength &left_pad_length)
Definition: multi_index_transform.hpp:207
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:251
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:265
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:213
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:220
decltype(make_tuple(LowLength{}+LeftPadLength{})) UpLengths
Definition: multi_index_transform.hpp:200
__host__ constexpr __device__ LeftPad()=default
LeftPadLength left_pad_length_
Definition: multi_index_transform.hpp:203
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:234
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:271
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:217
UpLengths up_lengths_
Definition: multi_index_transform.hpp:202
__host__ constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:260
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:253
Definition: magic_division.hpp:27
Definition: multi_index_transform.hpp:481
__host__ __device__ void UpdateLowerIndex_2(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition: multi_index_transform.hpp:822
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:967
LowLengths low_lengths_
Definition: multi_index_transform.hpp:493
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:508
__host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition: multi_index_transform.hpp:680
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:969
UpLengths up_lengths_
Definition: multi_index_transform.hpp:495
static constexpr index_t NDimLow
Definition: multi_index_transform.hpp:482
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:983
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:515
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition: multi_index_transform.hpp:491
__host__ constexpr __device__ Merge_v1_carry_check(const LowLengths &low_lengths)
Definition: multi_index_transform.hpp:499
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:974
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:510
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number< 1 >{})) LowLengthsScan
Definition: multi_index_transform.hpp:488
__host__ constexpr __device__ Merge_v1_carry_check()=default
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:988
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:512
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition: multi_index_transform.hpp:952
LowLengthsScan low_lengths_scan_
Definition: multi_index_transform.hpp:494
__host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition: multi_index_transform.hpp:537
Definition: multi_index_transform.hpp:1036
LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_
Definition: multi_index_transform.hpp:1055
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1138
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1056
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1077
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1073
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1075
LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_
Definition: multi_index_transform.hpp:1054
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition: multi_index_transform.hpp:1043
LowLengths low_lengths_
Definition: multi_index_transform.hpp:1053
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1136
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1080
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition: multi_index_transform.hpp:1105
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1153
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1143
decltype(generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier< LowLengths >{}, Number< NDimLow >{})) LowLengthsMagicDivisorMultipiler
Definition: multi_index_transform.hpp:1047
__host__ constexpr __device__ Merge_v2_magic_division(const LowLengths &low_lengths)
Definition: multi_index_transform.hpp:1060
decltype(generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift< LowLengths >{}, Number< NDimLow >{})) LowLengthsMagicDivisorShift
Definition: multi_index_transform.hpp:1051
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1158
__host__ constexpr __device__ Merge_v2_magic_division()=default
Definition: multi_index_transform.hpp:1188
LowLengths low_lengths_
Definition: multi_index_transform.hpp:1208
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1315
__host__ constexpr __device__ Merge_v2r2_magic_division()=default
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number< 1 >{})) LowLengthsScan
Definition: multi_index_transform.hpp:1195
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1295
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition: multi_index_transform.hpp:1198
LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_
Definition: multi_index_transform.hpp:1211
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1212
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1233
LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_
Definition: multi_index_transform.hpp:1210
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1293
__host__ constexpr __device__ Merge_v2r2_magic_division(const LowLengths &low_lengths)
Definition: multi_index_transform.hpp:1216
decltype(generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift< LowLengthsScan >{}, Number< NDimLow >{})) LowLengthsScanMagicDivisorShift
Definition: multi_index_transform.hpp:1206
decltype(generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier< LowLengthsScan >{}, Number< NDimLow >{})) LowLengthsScanMagicDivisorMultipiler
Definition: multi_index_transform.hpp:1202
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition: multi_index_transform.hpp:1263
LowLengthsScan low_lengths_scan_
Definition: multi_index_transform.hpp:1209
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1235
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1310
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1231
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1300
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1238
Definition: multi_index_transform.hpp:1338
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1352
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition: multi_index_transform.hpp:1394
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1442
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1365
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1437
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1428
LowLengthsScan low_lengths_scan_
Definition: multi_index_transform.hpp:1351
__host__ constexpr __device__ Merge_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform.hpp:1356
LowLengths low_lengths_
Definition: multi_index_transform.hpp:1350
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1367
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1372
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition: multi_index_transform.hpp:1348
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1369
__host__ constexpr __device__ Merge_v3_division_mod()=default
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1423
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number< 1 >{})) LowLengthsScan
Definition: multi_index_transform.hpp:1345
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1421
Definition: multi_index_transform.hpp:2065
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:2127
__host__ constexpr __device__ Modulo(const Modulus &modulus, const UpLength &up_length)
Definition: multi_index_transform.hpp:2075
Modulus modulus_
Definition: multi_index_transform.hpp:2070
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:2137
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:2120
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:2082
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &up_idx, Number< Hack >) const
Definition: multi_index_transform.hpp:2101
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:2084
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:2132
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:2118
decltype(make_tuple(UpLength{})) UpLengths
Definition: multi_index_transform.hpp:2068
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:2080
__host__ constexpr __device__ Modulo()=default
UpLengths up_lengths_
Definition: multi_index_transform.hpp:2071
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:2087
Definition: multi_index_transform.hpp:100
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:142
LeftPadLength left_pad_length_
Definition: multi_index_transform.hpp:107
__host__ constexpr __device__ Pad()=default
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:125
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:128
decltype(make_tuple(LowLength{}+LeftPadLength{}+RightPadLength{})) UpLengths
Definition: multi_index_transform.hpp:104
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:159
__host__ constexpr __device__ Pad(const LowLength &low_length, const LeftPadLength &left_pad_length, const RightPadLength &right_pad_length)
Definition: multi_index_transform.hpp:112
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:161
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:175
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:123
UpLengths up_lengths_
Definition: multi_index_transform.hpp:106
RightPadLength right_pad_length_
Definition: multi_index_transform.hpp:108
__host__ constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:168
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:182
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:121
Definition: multi_index_transform.hpp:13
__host__ constexpr __device__ PassThrough(const LowLength &low_length)
Definition: multi_index_transform.hpp:23
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:30
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:68
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:32
UpLengths up_lengths_
Definition: multi_index_transform.hpp:19
__host__ constexpr __device__ PassThrough()=default
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:66
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:49
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:85
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:75
decltype(make_tuple(LowLength{})) UpLengths
Definition: multi_index_transform.hpp:17
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:28
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:80
__host__ static constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up)
Definition: multi_index_transform.hpp:35
Definition: multi_index_transform.hpp:284
__host__ static constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up)
Definition: multi_index_transform.hpp:311
__host__ constexpr __device__ RightPad()=default
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:342
__host__ constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:351
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:308
decltype(make_tuple(LowLength{}+RightPadLength{})) UpLengths
Definition: multi_index_transform.hpp:288
__host__ constexpr __device__ RightPad(const LowLength &low_length, const RightPadLength &right_pad_length)
Definition: multi_index_transform.hpp:296
UpLengths up_lengths_
Definition: multi_index_transform.hpp:290
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:356
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:304
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:325
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:306
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:363
LowLength low_length_
Definition: multi_index_transform.hpp:291
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:344
RightPadLength right_pad_length_
Definition: multi_index_transform.hpp:292
Definition: multi_index_transform.hpp:1968
SliceEnd slice_end_
Definition: multi_index_transform.hpp:1976
decltype(make_tuple(SliceEnd{} - SliceBegin{})) UpLengths
Definition: multi_index_transform.hpp:1972
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1989
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:2047
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1974
SliceBegin slice_begin_
Definition: multi_index_transform.hpp:1975
__host__ constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &) const
Definition: multi_index_transform.hpp:2035
__host__ constexpr __device__ Slice(const LowLength &, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: multi_index_transform.hpp:1980
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1996
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:2029
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1993
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition: multi_index_transform.hpp:2010
__host__ constexpr __device__ Slice()=default
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1991
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:2040
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:2027
Definition: tuple.hpp:186
Definition: multi_index_transform.hpp:1458
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1467
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1533
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1483
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1544
UpLengthsScan up_lengths_scan_
Definition: multi_index_transform.hpp:1468
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1524
__host__ constexpr __device__ UnMerge()=default
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1538
__host__ constexpr __device__ UnMerge(const UpLengths &up_lengths)
Definition: multi_index_transform.hpp:1472
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1486
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition: multi_index_transform.hpp:1513
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1479
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1481
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1526
decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, Number< 1 >{})) UpLengthsScan
Definition: multi_index_transform.hpp:1465
Definition: multi_index_transform.hpp:1882
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:1903
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:1939
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:1951
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition: multi_index_transform.hpp:1920
__host__ constexpr __device__ Vectorize()=default
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:1901
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:1899
VectorSize vector_size_
Definition: multi_index_transform.hpp:1889
__host__ constexpr __device__ Vectorize(const VectorSize &vector_size, const UpLength &up_length)
Definition: multi_index_transform.hpp:1893
decltype(make_tuple(UpLength{})) UpLengths
Definition: multi_index_transform.hpp:1886
UpLengths up_lengths_
Definition: multi_index_transform.hpp:1888
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:1906
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:1956
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: multi_index_transform.hpp:1937
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:1946
Definition: multi_index_transform.hpp:2149
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: multi_index_transform.hpp:2165
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: multi_index_transform.hpp:2168
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: multi_index_transform.hpp:2163
__host__ constexpr __device__ Xor()
Definition: multi_index_transform.hpp:2157
UpLengths up_lengths_
Definition: multi_index_transform.hpp:2155
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up, Number< Hack >) const
Definition: multi_index_transform.hpp:2192
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: multi_index_transform.hpp:2209
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: multi_index_transform.hpp:2216
__host__ __device__ void Print() const
Definition: multi_index_transform.hpp:2226
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: multi_index_transform.hpp:2221
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: multi_index_transform.hpp:2161
__host__ constexpr __device__ Xor(const LowLengths &low_lengths)
Definition: multi_index_transform.hpp:2159
LowLengths UpLengths
Definition: multi_index_transform.hpp:2153
Definition: integral_constant.hpp:20
Definition: is_known_at_compile_time.hpp:14
__host__ constexpr __device__ auto operator()(Number< I > i) const
Definition: multi_index_transform.hpp:1006
Definition: multi_index_transform.hpp:1014
__host__ constexpr __device__ auto operator()(Number< I > i) const
Definition: multi_index_transform.hpp:1016
Definition: math.hpp:34
Definition: functional2.hpp:33