/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/reduce_operator.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/reduce_operator.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/reduce_operator.hpp Source File
reduce_operator.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
9 
10 namespace ck_tile {
11 
12 namespace ReduceOp {
13 // y = ReduceOp(y, x);
// Sum reduction functor: y = y + x.
// NOTE(review): this listing is an extract with gaps. Listing lines 17, 23, 30
// and 41 are absent, so one signature, the second template parameter of both
// call operators, and the GetAtomic body cannot be read here.
14 struct Add
15 {
    // Body of the identity-value accessor: returns zero converted to T.
    // The signature (listing line 17) is missing from this extract; the
    // cross-reference index at the end of the page names it
    // `static constexpr T GetIdentityValue()`.
16  template <typename T>
18  {
19  return type_convert<T>(0.0f);
20  };
21 
    // Adds x to y directly in type T and returns the sum; neither argument is
    // modified. The second template parameter (listing line 23) is missing here,
    // presumably an enable_if constraint on T -- TODO confirm.
22  template <typename T,
24  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
25  {
26  return y + x;
27  }
28 
    // Variant that widens both operands to float, adds them, and converts the
    // result back to T. Although y is taken by non-const reference it is only
    // read, never written. The second template parameter (listing line 30) is
    // missing here; presumably it selects reduced-precision types -- TODO confirm.
29  template <typename T,
31  CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
32  {
33  float y_ = type_convert<float>(y);
34  float x_ = type_convert<float>(x);
35 
36  return type_convert<T>(y_ + x_);
37  }
38 
    // Accessor whose return statement (listing line 41) is missing from this
    // extract, so the returned value cannot be stated from what is visible.
    // NOTE(review): the page index references atomic_add; presumably this
    // reports the matching atomic operation -- verify against the real header.
39  CK_TILE_HOST_DEVICE static constexpr auto GetAtomic()
40  {
42  }
43 };
44 
// Sum-of-squares reduction functor: y = y + x * x.
// NOTE(review): this listing is an extract with gaps. Listing lines 48, 54 and
// 61 are absent, so one signature and the second template parameter of both
// call operators cannot be read here.
45 struct SquareAdd
46 {
    // Body of the identity-value accessor: returns zero converted to T.
    // The signature (listing line 48) is missing from this extract; the
    // cross-reference index at the end of the page names it
    // `static constexpr T GetIdentityValue()`.
47  template <typename T>
49  {
50  return type_convert<T>(0.0f);
51  };
52 
    // Squares x and adds it to y, all in type T; neither argument is modified.
    // The second template parameter (listing line 54) is missing here,
    // presumably an enable_if constraint on T -- TODO confirm.
53  template <typename T,
55  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
56  {
57  return y + (x * x);
58  }
59 
    // Variant that widens both operands to float, computes y + x * x in float,
    // and converts the result back to T. y is taken by non-const reference but
    // is only read. The second template parameter (listing line 61) is missing
    // here; presumably it selects reduced-precision types -- TODO confirm.
60  template <typename T,
62  CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
63  {
64  float y_ = type_convert<float>(y);
65  float x_ = type_convert<float>(x);
66  return type_convert<T>(y_ + (x_ * x_));
67  }
68 };
69 
// Maximum reduction functor: y = max(y, x).
// NOTE(review): this listing is an extract with gaps. Listing lines 75-76, 84
// and 94 are absent, so every enable_if_t condition and one signature cannot be
// read here.
70 struct Max
71 {
    // Body of the identity-value accessor: returns numeric<T>::lowest().
    // The enable_if_t condition and the signature (listing lines 75-76) are
    // missing from this extract; the cross-reference index at the end of the
    // page names it `static constexpr T GetIdentityValue()`.
72  template <
73  typename T,
74  typename = std::enable_if_t<
77  {
78  return numeric<T>::lowest();
79  };
80 
    // Returns the larger of y and x via max(); neither argument is modified.
    // The enable_if_t condition (listing line 84) is missing from this extract.
81  template <
82  typename T,
83  typename = std::enable_if_t<
85  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
86  {
87  return max(y, x);
88  }
89 
    // Same result as the two-argument overload, and additionally sets `changed`
    // to true when x is strictly greater than y. `changed` is never written
    // with false, so the caller must initialize it; when x == y it is left
    // untouched. The enable_if_t condition (listing line 94) is missing here.
90  // Overload with changed flag for index tracking
91  template <
92  typename T,
93  typename = std::enable_if_t<
95  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
96  {
97  T new_max = max(y, x);
98  if(x > y)
99  {
100  changed = true;
101  }
102  return new_max;
103  }
104 };
105 
// Absolute-maximum reduction functor: y = max(y, abs(x)).
// NOTE(review): this listing is an extract with gaps. Listing lines 111-112,
// 120 and 130 are absent, so every enable_if_t condition and one signature
// cannot be read here.
106 struct AbsMax
107 {
    // Body of the identity-value accessor: returns numeric<T>::zero().
    // The enable_if_t condition and the signature (listing lines 111-112) are
    // missing from this extract; the cross-reference index at the end of the
    // page names it `static constexpr T GetIdentityValue()`.
108  template <
109  typename T,
110  typename = std::enable_if_t<
113  {
114  return numeric<T>::zero();
115  };
116 
    // Returns the larger of y and abs(x). Only x has abs() applied; y is
    // compared as given. Neither argument is modified.
    // The enable_if_t condition (listing line 120) is missing from this extract.
117  template <
118  typename T,
119  typename = std::enable_if_t<
121  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
122  {
123  return max(y, abs(x));
124  }
125 
    // Same result as the two-argument overload, and additionally sets `changed`
    // to true when abs(x) is strictly greater than y. `changed` is never
    // written with false, so the caller must initialize it; when abs(x) == y it
    // is left untouched. The enable_if_t condition (listing line 130) is
    // missing here.
126  // Overload with changed flag for index tracking
127  template <
128  typename T,
129  typename = std::enable_if_t<
131  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
132  {
133  T new_max = max(y, abs(x));
134  if(abs(x) > y)
135  {
136  changed = true;
137  }
138  return new_max;
139  }
140 };
141 
142 } // namespace ReduceOp
143 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:403
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
__device__ X atomic_add(X *p_dst, const X &x)
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1697
Definition: reduce_operator.hpp:107
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x, bool &changed) const
Definition: reduce_operator.hpp:131
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x) const
Definition: reduce_operator.hpp:121
static constexpr CK_TILE_HOST_DEVICE T GetIdentityValue()
Definition: reduce_operator.hpp:112
Definition: reduce_operator.hpp:15
constexpr CK_TILE_HOST_DEVICE T operator()(T &y, T x) const
Definition: reduce_operator.hpp:31
static constexpr CK_TILE_HOST_DEVICE T GetIdentityValue()
Definition: reduce_operator.hpp:17
static constexpr CK_TILE_HOST_DEVICE auto GetAtomic()
Definition: reduce_operator.hpp:39
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x) const
Definition: reduce_operator.hpp:24
Definition: reduce_operator.hpp:71
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x, bool &changed) const
Definition: reduce_operator.hpp:95
static constexpr CK_TILE_HOST_DEVICE T GetIdentityValue()
Definition: reduce_operator.hpp:76
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x) const
Definition: reduce_operator.hpp:85
Definition: reduce_operator.hpp:46
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x) const
Definition: reduce_operator.hpp:55
constexpr CK_TILE_HOST_DEVICE T operator()(T &y, T x) const
Definition: reduce_operator.hpp:62
static constexpr CK_TILE_HOST_DEVICE T GetIdentityValue()
Definition: reduce_operator.hpp:48
Definition: type_traits.hpp:115
static constexpr CK_TILE_HOST_DEVICE T lowest()
Definition: numeric.hpp:23
static constexpr CK_TILE_HOST_DEVICE T zero()
Definition: numeric.hpp:58