/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp Source File
binary_elementwise_operation.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 namespace element_wise {
10 
11 struct Add
12 {
13  template <typename Y, typename X0, typename X1>
14  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
15 
16  template <>
17  __host__ __device__ constexpr void
18  operator()<float>(float& y, const float& x0, const float& x1) const
19  {
20  y = x0 + x1;
21  };
22 
23  template <>
24  __host__ __device__ constexpr void
25  operator()<double>(double& y, const double& x0, const double& x1) const
26  {
27  y = x0 + x1;
28  };
29 
30  template <>
31  __host__ __device__ constexpr void
32  operator()<float>(float& y, const float& x0, const half_t& x1) const
33  {
34  y = x0 + type_convert<half_t>(x1);
35  };
36 
37  template <>
38  __host__ __device__ constexpr void
39  operator()<half_t>(half_t& y, const float& x0, const float& x1) const
40  {
41  y = type_convert<half_t>(x0 + x1);
42  };
43 
44  template <>
45  __host__ __device__ constexpr void
46  operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
47  {
48  y = type_convert<half_t>(x0) + x1;
49  };
50 
51  template <>
52  __host__ __device__ constexpr void
53  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
54  {
55  y = x0 + x1;
56  };
57 
58  template <>
59  __host__ __device__ constexpr void
60  operator()<float>(float& y, const float& x0, const bf16_t& x1) const
61  {
62  const float x1_tmp = type_convert<float>(x1);
63  y = x0 + x1_tmp;
64  }
65 
66  template <>
67  __host__ __device__ constexpr void
68  operator()<bf16_t>(bf16_t& y, const bf16_t& x0, const bf16_t& x1) const
69  {
70  const float x1_tmp = type_convert<float>(x0);
71  const float x2_tmp = type_convert<float>(x1);
72  const float y_tmp = x1_tmp + x2_tmp;
73  y = type_convert<bf16_t>(y_tmp);
74  }
75 
76  template <>
77  __host__ __device__ constexpr void
78  operator()<bf16_t>(bf16_t& y, const float& x0, const bf16_t& x1) const
79  {
80  const float x2_tmp = type_convert<float>(x1);
81  const float y_tmp = x0 + x2_tmp;
82  y = type_convert<bf16_t>(y_tmp);
83  }
84 
85  template <>
86  __host__ __device__ constexpr void
87  operator()<bf16_t>(bf16_t& y, const float& x0, const float& x1) const
88  {
89  const float y_tmp = x0 + x1;
90  y = type_convert<bf16_t>(y_tmp);
91  }
92 
93  template <>
94  __host__ __device__ constexpr void
95  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
96  {
97  y = x0 + x1;
98  };
99 };
100 
101 } // namespace element_wise
102 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
_Float16 half_t
Definition: half.hpp:111
Definition: binary_elementwise_operation.hpp:12
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const