/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_description/multi_index_transform_helper.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_description/multi_index_transform_helper.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_description/multi_index_transform_helper.hpp Source File
multi_index_transform_helper.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 
11 template <typename LowLength>
12 __host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
13 {
14  return PassThrough<LowLength>{low_length};
15 }
16 
17 template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
18 __host__ __device__ constexpr auto
19 make_pad_transform(const LowLength& low_length,
20  const LeftPad& left_pad,
21  const RightPad& right_pad,
23 {
24  return Pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
25 }
26 
27 template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
28 __host__ __device__ constexpr auto make_left_pad_transform(
29  const LowLength& low_length,
30  const LeftPadLength& left_pad,
32 {
33  return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
34 }
35 
36 template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
37 __host__ __device__ constexpr auto make_right_pad_transform(
38  const LowLength& low_length,
39  const RightPadLength& right_pad,
41 {
42  return RightPad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad};
43 }
44 
45 template <typename UpLengths,
46  typename Coefficients,
47  typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
48 __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
49  const Coefficients& coefficients)
50 {
51  return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
52 }
53 
54 template <typename LowLengths>
55 __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
56 {
57 #if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
58  return make_merge_transform_v2_magic_division(low_lengths);
59 #else
60  return make_merge_transform_v1_carry_check(low_lengths);
61 #endif
62 }
63 
64 template <typename LowLengths>
65 __host__ __device__ constexpr auto
66 make_merge_transform_v1_carry_check(const LowLengths& low_lengths)
67 {
68  return Merge_v1_carry_check<LowLengths>{low_lengths};
69 }
70 
71 template <typename LowLengths>
72 __host__ __device__ constexpr auto
73 make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
74 {
75 #if 1
76  return Merge_v2_magic_division<LowLengths>{low_lengths};
77 #else
78  return Merge_v2r2_magic_division<LowLengths>{low_lengths};
79 #endif
80 }
81 
82 template <typename LowLengths>
83 __host__ __device__ constexpr auto
84 make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
85 {
86  return Merge_v3_division_mod<LowLengths>{low_lengths};
87 }
88 
89 template <typename UpLengths, bool Use24BitIntegerCalculation = false>
90 __host__ __device__ constexpr auto make_unmerge_transform(
91  const UpLengths& up_lengths,
93 {
94  return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
95 }
96 
97 template <typename LowerIndex>
98 __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
99 {
100  return Freeze<LowerIndex>{low_idx};
101 }
102 
103 template <typename UpperIndex>
104 __host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx)
105 {
106  return Insert<UpperIndex>{up_idx};
107 }
108 
109 template <typename LowLength, typename SliceBegin, typename SliceEnd>
110 __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length,
111  const SliceBegin& slice_begin,
112  const SliceEnd& slice_end)
113 {
114  return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
115 }
116 
117 template <typename VectorSize, typename UpLength>
118 __host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
119  const UpLength& up_length)
120 {
121  return Vectorize<VectorSize, UpLength>{vector_size, up_length};
122 }
123 
124 template <typename Modulus, typename UpLength>
125 __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
126  const UpLength& up_length)
127 {
128  return Modulo<Modulus, UpLength>{modulus, up_length};
129 }
130 
131 template <typename LowLengths>
132 __host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths& low_lengths)
133 {
134  return Xor<LowLengths, true /*ApplyModulo*/>{low_lengths};
135 }
136 
137 template <typename LowLengths>
138 __host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths)
139 {
140  return Xor<LowLengths, false /*ApplyModulo*/>{low_lengths};
141 }
142 } // namespace ck
Definition: ck.hpp:266
__host__ constexpr __device__ auto make_left_pad_transform(const LowLength &low_length, const LeftPadLength &left_pad, integral_constant< bool, SkipIsValidCheck > = integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:28
__host__ constexpr __device__ auto make_xor_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:138
__host__ constexpr __device__ auto make_merge_transform_v2_magic_division(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:73
__host__ constexpr __device__ auto make_merge_transform_v1_carry_check(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:66
__host__ constexpr __device__ auto make_vectorize_transform(const VectorSize &vector_size, const UpLength &up_length)
Definition: multi_index_transform_helper.hpp:118
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
__host__ constexpr __device__ auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: multi_index_transform_helper.hpp:48
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:132
__host__ constexpr __device__ auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: multi_index_transform_helper.hpp:110
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_insert_transform(const UpperIndex &up_idx)
Definition: multi_index_transform_helper.hpp:104
__host__ constexpr __device__ auto make_modulo_transform(const Modulus &modulus, const UpLength &up_length)
Definition: multi_index_transform_helper.hpp:125
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation > = integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
__host__ constexpr __device__ auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck > = integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:19
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck > = integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: multi_index_transform.hpp:385
Definition: multi_index_transform.hpp:1558
Definition: multi_index_transform.hpp:1624
Definition: multi_index_transform.hpp:196
Definition: multi_index_transform.hpp:481
Definition: multi_index_transform.hpp:1036
Definition: multi_index_transform.hpp:1188
Definition: multi_index_transform.hpp:1338
Definition: multi_index_transform.hpp:1873
Definition: multi_index_transform.hpp:13
Definition: multi_index_transform.hpp:284
Definition: multi_index_transform.hpp:1776
Definition: multi_index_transform.hpp:1690
Definition: multi_index_transform.hpp:1957
Definition: integral_constant.hpp:20