/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/math.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/math.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/math.hpp Source File
math.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/ck.hpp"
7 #include "integral_constant.hpp"
8 #include "number.hpp"
9 #include "type.hpp"
10 #include "enable_if.hpp"
11 
12 namespace ck {
13 namespace math {
14 
template <typename T, T s>
struct scales
{
    // Unary functor: multiplies its operand by the compile-time factor `s`
    // and returns the product as a T.
    __host__ __device__ constexpr T operator()(T value) const
    {
        const T scaled = s * value;
        return scaled;
    }
};
20 
template <typename T>
struct plus
{
    // Binary addition functor; the sum is returned as a T.
    __host__ __device__ constexpr T operator()(T lhs, T rhs) const
    {
        const T sum = lhs + rhs;
        return sum;
    }
};
26 
template <typename T>
struct minus
{
    // Binary subtraction functor: returns lhs - rhs as a T.
    __host__ __device__ constexpr T operator()(T lhs, T rhs) const
    {
        const T difference = lhs - rhs;
        return difference;
    }
};
32 
struct multiplies
{
    // Heterogeneous multiplication functor. The operands may have different
    // types; the result type is whatever `lhs * rhs` yields (decayed by `auto`).
    template <typename A, typename B>
    __host__ __device__ constexpr auto operator()(const A& lhs, const B& rhs) const
    {
        return lhs * rhs;
    }
};
41 
template <typename T>
struct maximize
{
    // Returns the larger operand. The first operand wins whenever a >= b
    // holds (including ties); otherwise the second operand is returned.
    __host__ __device__ constexpr T operator()(T a, T b) const
    {
        if(a >= b)
        {
            return a;
        }
        return b;
    }
};
47 
template <typename T>
struct minimize
{
    // Returns the smaller operand. The first operand wins whenever a <= b
    // holds (including ties); otherwise the second operand is returned.
    __host__ __device__ constexpr T operator()(T a, T b) const
    {
        if(a <= b)
        {
            return a;
        }
        return b;
    }
};
53 
template <typename T>
struct integer_divide_ceiler
{
    // Ceiling-division functor for 32-bit index types: returns ceil(a / b)
    // using the biased-numerator form (a + b - 1) / b.
    // The formula is an exact ceiling for a >= 0 and b > 0. For negative `a`,
    // truncating `/` makes it wrong (e.g. a = -4, b = 2 gives -1).
    // NOTE(review): a + b - 1 can overflow when `a` is near the index_t maximum.
    __host__ __device__ constexpr T operator()(T a, T b) const
    {
        static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");

        return (a + b - Number<1>{}) / b;
    }
};
64 
template <typename X, typename Y>
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
{
    // Plain `/` division. For built-in integers this truncates toward zero, so
    // it equals the mathematical floor only when the quotient is non-negative
    // (e.g. -7 / 2 yields -3, not -4). X and Y are left generic so any operand
    // types with an operator/ are accepted; the result type is that of x / y.
    const auto quotient = x / y;
    return quotient;
}
70 
template <typename X, typename Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
{
    // ceil(x / y) via the biased numerator (x + y - 1) / y; exact for x >= 0, y > 0.
    // The bias is expressed with Number<1>{} so compile-time Number operands
    // keep producing a compile-time result.
    // NOTE(review): x + y - 1 can overflow when x is near the index_t maximum.
    const auto biased_numerator = x + y - Number<1>{};
    return biased_numerator / y;
}
76 
template <typename X, typename Y>
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
{
    // Rounds x up to a multiple of y: y * ceil(x / y). For x >= 0, y > 0 this is
    // the smallest multiple of y that is not less than x.
    const auto num_chunks = integer_divide_ceil(x, y);
    return y * num_chunks;
}
82 
template <typename T>
__host__ __device__ constexpr T max(T only)
{
    // Single-argument base case used by the variadic max: the value itself.
    return only;
}
88 
template <typename T>
__host__ __device__ constexpr T max(T x, T y)
{
    // x is returned only when x > y strictly; ties (and unordered values)
    // return y.
    if(x > y)
    {
        return x;
    }
    return y;
}
94 
template <index_t X>
__host__ __device__ constexpr index_t max(Number<X>, index_t y)
{
    // Mixed overload: compile-time left operand X, run-time right operand y.
    // Ties return y.
    if(X > y)
    {
        return X;
    }
    return y;
}
100 
template <index_t Y>
__host__ __device__ constexpr index_t max(index_t x, Number<Y>)
{
    // Mixed overload: run-time left operand x, compile-time right operand Y.
    // Ties return Y.
    if(x > Y)
    {
        return x;
    }
    return Y;
}
106 
template <typename X, typename... Ys>
__host__ __device__ constexpr auto max(X x, Ys... ys)
{
    // Variadic maximum: reduce the tail first, then combine it with the head.
    static_assert(sizeof...(Ys) > 0, "not enough argument");

    const auto tail_max = max(ys...);
    return max(x, tail_max);
}
114 
template <typename T>
__host__ __device__ constexpr T min(T only)
{
    // Single-argument base case used by the variadic min: the value itself.
    return only;
}
120 
template <typename T>
__host__ __device__ constexpr T min(T x, T y)
{
    // x is returned only when x < y strictly; ties (and unordered values)
    // return y.
    if(x < y)
    {
        return x;
    }
    return y;
}
126 
template <index_t X>
__host__ __device__ constexpr index_t min(Number<X>, index_t y)
{
    // Mixed overload: compile-time left operand X, run-time right operand y.
    // Ties return y.
    if(X < y)
    {
        return X;
    }
    return y;
}
132 
template <index_t Y>
__host__ __device__ constexpr index_t min(index_t x, Number<Y>)
{
    // Mixed overload: run-time left operand x, compile-time right operand Y.
    // Ties return Y.
    if(x < Y)
    {
        return x;
    }
    return Y;
}
138 
template <typename X, typename... Ys>
__host__ __device__ constexpr auto min(X x, Ys... ys)
{
    // Variadic minimum: reduce the tail first, then combine it with the head.
    static_assert(sizeof...(Ys) > 0, "not enough argument");

    const auto tail_min = min(ys...);
    return min(x, tail_min);
}
146 
template <typename T>
__host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
{
    // Raise x to at least lowerbound, then cap the result at upperbound.
    // If lowerbound > upperbound, the result is upperbound.
    const T at_least_lower = max(x, lowerbound);
    return min(at_least_lower, upperbound);
}
152 
// greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{
    // Work on magnitudes so the result is non-negative.
    // NOTE(review): negating the most negative index_t overflows; callers are
    // presumably never expected to pass it.
    if(x < 0)
    {
        x = -x;
    }
    if(y < 0)
    {
        y = -y;
    }

    // Iterative Euclid: gcd(0, 0) == 0, gcd(v, 0) == gcd(0, v) == v.
    while(y != 0)
    {
        const index_t remainder = x % y;
        x                       = y;
        y                       = remainder;
    }
    return x;
}
181 
template <index_t X, index_t Y>
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
{
    // Compile-time overload: evaluate the index_t gcd during constant
    // evaluation and wrap the result in a Number so it stays a type-level constant.
    return Number<gcd(X, Y)>{};
}
189 
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
{
    // Three or more arguments: fold the tail, then combine it with the head.
    // The enable_if keeps this overload out of two-argument calls.
    const auto tail_gcd = gcd(ys...);
    return gcd(x, tail_gcd);
}
195 
// least common multiple
template <typename X, typename Y>
__host__ __device__ constexpr auto lcm(X x, Y y)
{
    // Divide before multiplying. gcd(x, y) divides x exactly, so (x / g) * y
    // equals (x * y) / g, but it does not overflow on the intermediate
    // product x * y (e.g. lcm(65536, 65536) with 32-bit index_t).
    // As before, x == y == 0 divides by zero.
    return (x / gcd(x, y)) * y;
}
202 
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
{
    // Three or more arguments: fold the tail, then combine it with the head.
    // The enable_if keeps this overload out of two-argument calls.
    const auto tail_lcm = lcm(ys...);
    return lcm(x, tail_lcm);
}
208 
template <typename T>
struct equal
{
    // Equality predicate functor: true when lhs == rhs.
    __host__ __device__ constexpr bool operator()(T lhs, T rhs) const
    {
        const bool same = (lhs == rhs);
        return same;
    }
};
214 
template <typename T>
struct less
{
    // Strict-ordering predicate functor: true when lhs < rhs.
    __host__ __device__ constexpr bool operator()(T lhs, T rhs) const
    {
        const bool ordered = (lhs < rhs);
        return ordered;
    }
};
220 
template <index_t X>
__host__ __device__ constexpr auto next_power_of_two()
{
    // Smallest power of two that is >= X, returned as a run-time index_t
    // (e.g. 2 -> 2, 3 -> 4, 5 -> 8).
    // Valid range is [1, 2^30]:
    //  * X == 1 maps to 1. The bit trick would evaluate __builtin_clz(0),
    //    which is undefined, so that case is special-cased.
    //  * X > 2^30 would need 1 << 31, which does not fit in a 32-bit signed
    //    index_t, so it is rejected at compile time.
    static_assert(X >= 1 && X <= (1 << 30), "next_power_of_two: X must be in [1, 2^30]");

    constexpr index_t Y =
        (X == 1) ? 1 : (1 << (32 - __builtin_clz(static_cast<unsigned int>(X - 1))));
    return Y;
}
228 
template <index_t X>
__host__ __device__ constexpr auto next_power_of_two(Number<X>)
{
    // Compile-time variant: smallest power of two that is >= X, returned as a
    // Number so it stays a type-level constant.
    // Valid range is [1, 2^30]:
    //  * X == 1 maps to 1. The bit trick would evaluate __builtin_clz(0),
    //    which is undefined, so that case is special-cased.
    //  * X > 2^30 would need 1 << 31, which does not fit in a 32-bit signed
    //    index_t, so it is rejected at compile time.
    static_assert(X >= 1 && X <= (1 << 30), "next_power_of_two: X must be in [1, 2^30]");

    constexpr index_t Y =
        (X == 1) ? 1 : (1 << (32 - __builtin_clz(static_cast<unsigned int>(X - 1))));
    return Number<Y>{};
}
236 
237 } // namespace math
238 } // namespace ck
__host__ constexpr __device__ auto next_power_of_two()
Definition: math.hpp:222
__host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
__host__ constexpr __device__ auto integer_divide_floor(X x, Y y)
Definition: math.hpp:66
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
__host__ constexpr __device__ index_t gcd(index_t x, index_t y)
Definition: math.hpp:154
__host__ __device__ equal() -> equal< void, void >
FIXME: create macro to replace 'host device' and nothing more.
Definition: ck.hpp:267
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
int32_t index_t
Definition: ck.hpp:298
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
Definition: integral_constant.hpp:20
static constexpr T value
Definition: integral_constant.hpp:21
Definition: type.hpp:177
Definition: math.hpp:56
__host__ constexpr __device__ T operator()(T a, T b) const
Definition: math.hpp:57
Definition: math.hpp:217
__host__ constexpr __device__ bool operator()(T x, T y) const
Definition: math.hpp:218
Definition: math.hpp:44
__host__ constexpr __device__ T operator()(T a, T b) const
Definition: math.hpp:45
Definition: math.hpp:50
__host__ constexpr __device__ T operator()(T a, T b) const
Definition: math.hpp:51
Definition: math.hpp:29
__host__ constexpr __device__ T operator()(T a, T b) const
Definition: math.hpp:30
Definition: math.hpp:34
__host__ constexpr __device__ auto operator()(const A &a, const B &b) const
Definition: math.hpp:36
Definition: math.hpp:23
__host__ constexpr __device__ T operator()(T a, T b) const
Definition: math.hpp:24
Definition: math.hpp:17
__host__ constexpr __device__ T operator()(T a) const
Definition: math.hpp:18