/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/reduce_operator.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/reduce_operator.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/reduce_operator.hpp Source File
reduce_operator.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck_tile {
9 
10 namespace ReduceOp {
11 // y = ReduceOp(y, x);
12 struct Add
13 {
14  template <typename T>
16  {
17  return type_convert<T>(0.0f);
18  };
19 
20  template <typename T,
21  typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
22  std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
23  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
24  {
25  return y + x;
26  }
27 
28  template <typename T,
29  typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
30  std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
31  CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
32  {
33  float y_ = type_convert<float>(y);
34  float x_ = type_convert<float>(x);
35 
36  return type_convert<T>(y_ + x_);
37  }
38 };
39 
40 struct SquareAdd
41 {
42  template <typename T>
44  {
45  return type_convert<T>(0.0f);
46  };
47 
48  template <typename T,
49  typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
50  std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
51  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
52  {
53  return y + (x * x);
54  }
55 
56  template <typename T,
57  typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
58  std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
59  CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
60  {
61  float y_ = type_convert<float>(y);
62  float x_ = type_convert<float>(x);
63  return type_convert<T>(y_ + (x_ * x_));
64  }
65 };
66 
67 struct Max
68 {
69  template <typename T,
70  typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
71  std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
72  std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
73  std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
75  {
76  return numeric<T>::min();
77  };
78 
79  template <typename T,
80  typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
81  std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
82  std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
83  std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
84  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
85  {
86  return max(y, x);
87  }
88 };
89 
90 struct AbsMax
91 {
92  template <typename T,
93  typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
94  std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
95  std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
96  std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
98  {
99  return numeric<T>::min();
100  };
101 
102  template <typename T,
103  typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
104  std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
105  std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
106  std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
107  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
108  {
109  return max(y, abs(x));
110  }
111 };
112 
113 } // namespace ReduceOp
114 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:404
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
Definition: reduce_operator.hpp:91
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x) const
Definition: reduce_operator.hpp:107
static constexpr CK_TILE_HOST_DEVICE T GetIdentityValue()
Definition: reduce_operator.hpp:97
Definition: reduce_operator.hpp:13
constexpr CK_TILE_HOST_DEVICE T operator()(T &y, T x) const
Definition: reduce_operator.hpp:31
static constexpr CK_TILE_HOST_DEVICE T GetIdentityValue()
Definition: reduce_operator.hpp:15
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x) const
Definition: reduce_operator.hpp:23
Definition: reduce_operator.hpp:68
static constexpr CK_TILE_HOST_DEVICE T GetIdentityValue()
Definition: reduce_operator.hpp:74
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x) const
Definition: reduce_operator.hpp:84
Definition: reduce_operator.hpp:41
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x) const
Definition: reduce_operator.hpp:51
constexpr CK_TILE_HOST_DEVICE T operator()(T &y, T x) const
Definition: reduce_operator.hpp:59
static constexpr CK_TILE_HOST_DEVICE T GetIdentityValue()
Definition: reduce_operator.hpp:43
static constexpr CK_TILE_HOST_DEVICE T min()
Definition: numeric.hpp:20