/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp Source File
combined_element_wise_operation.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 namespace tensor_operation {
10 namespace element_wise {
11 
// Chains any number of unary element-wise operations, applied in template-argument order:
//   y = UnaryOpN(...UnaryOp1(UnaryOp0(x)))
// i.e. the FIRST type in UnaryOpsSet is applied first and the LAST type is applied last.
template <typename... UnaryOpsSet>
struct UnaryCombinedOp
{
    // With an empty pack there is no first op to move x into y, so reject it up front
    // with a readable diagnostic.
    static_assert(sizeof...(UnaryOpsSet) > 0,
                  "UnaryCombinedOp requires at least one unary operation");

    // Default-constructs every operation in the chain.
    __host__ __device__ UnaryCombinedOp() : unary_ops_() {}

    // Stores copies of the given operations, in application order.
    __host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}

    // Computes y from x by running every stored operation in sequence.
    // The first operation reads x (type X) and writes y (type Y). Every later operation
    // is invoked in place as op(y, y), so each op after the first must tolerate its
    // output and input referring to the same object.
    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        // First op moves/converts the input into the output.
        unary_ops_.At(Number<0>{})(y, x);

        // Remaining ops (indices 1 .. N-1) are applied in place on y, in order.
        static_for<1, Tuple<UnaryOpsSet...>::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); });
    }

    // The stored operations, in application order. Note: unlike the other combined
    // ops in this header, this member is public.
    Tuple<UnaryOpsSet...> unary_ops_;
};
31 
// Two-input element-wise combinator:
//   y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1))
// Each operand is first routed through its own unary op; both intermediate values are
// held in the output type Y before the binary op merges them into y.
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
struct BinaryWithUnaryCombinedOp
{
    private:
    // Declaration order (binary, unary0, unary1) is also the member-initialization order.
    BinaryOp binary_op_;
    UnaryOp0 unary_op0_;
    UnaryOp1 unary_op1_;

    public:
    // Value-initializes all three stored operations.
    __host__ __device__ BinaryWithUnaryCombinedOp()
        : binary_op_(), unary_op0_(), unary_op1_()
    {
    }

    // Stores copies of the supplied operations.
    __host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
                                                  UnaryOp0 unary_op0,
                                                  UnaryOp1 unary_op1)
        : binary_op_(binary_op),
          unary_op0_(unary_op0),
          unary_op1_(unary_op1)
    {
    }

    // Writes BinaryOp(UnaryOp0(x0), UnaryOp1(x1)) into y.
    template <typename Y, typename X0, typename X1>
    __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        // Stage both operands in Y-typed scratch values.
        Y lhs;
        Y rhs;
        unary_op0_(lhs, x0);
        unary_op1_(rhs, x1);

        // Merge the staged operands into the caller's output.
        binary_op_(y, lhs, rhs);
    }
};
60 
// Three-input element-wise combinator:
//   y = BinaryOp1(BinaryOp0(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2))
// BinaryOp0 merges the first two (unary-transformed) inputs; BinaryOp1 then merges that
// result with the third (unary-transformed) input. All intermediates are held in type Y.
template <typename BinaryOp0,
          typename BinaryOp1,
          typename UnaryOp0,
          typename UnaryOp1,
          typename UnaryOp2>
struct TrinaryWithUnaryCombinedOp
{
    // Value-initializes all five stored operations.
    __host__ __device__ TrinaryWithUnaryCombinedOp()
        : binary_op0_(), binary_op1_(), unary_op0_(), unary_op1_(), unary_op2_()
    {
    }

    // Stores copies of the supplied operations.
    // binary_op1 is taken as BinaryOp1 (matching the member it initializes), so this
    // constructor also works when BinaryOp0 and BinaryOp1 are unrelated types.
    __host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0,
                                                   BinaryOp1 binary_op1,
                                                   UnaryOp0 unary_op0,
                                                   UnaryOp1 unary_op1,
                                                   UnaryOp2 unary_op2)
        : binary_op0_(binary_op0),
          binary_op1_(binary_op1),
          unary_op0_(unary_op0),
          unary_op1_(unary_op1),
          unary_op2_(unary_op2)
    {
    }

    // Writes BinaryOp1(BinaryOp0(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2)) into y.
    template <typename Y, typename X0, typename X1, typename X2>
    __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const
    {
        Y unary_x0_tmp_result;
        Y unary_x1_tmp_result;
        Y unary_x2_tmp_result;
        unary_op0_(unary_x0_tmp_result, x0);
        unary_op1_(unary_x1_tmp_result, x1);
        unary_op2_(unary_x2_tmp_result, x2);

        // Use a dedicated temporary so BinaryOp0 never sees its output aliased with
        // one of its inputs.
        Y binary_x01_tmp_result;
        binary_op0_(binary_x01_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result);
        binary_op1_(y, binary_x01_tmp_result, unary_x2_tmp_result);
    }

    private:
    BinaryOp0 binary_op0_{};
    BinaryOp1 binary_op1_{};
    UnaryOp0 unary_op0_{};
    UnaryOp1 unary_op1_{};
    UnaryOp2 unary_op2_{};
};
108 
111 
112 } // namespace element_wise
113 } // namespace tensor_operation
114 } // namespace ck
Definition: ck.hpp:267
Definition: tuple.hpp:117
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: combined_element_wise_operation.hpp:35
__host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op, UnaryOp0 unary_op0, UnaryOp1 unary_op1)
Definition: combined_element_wise_operation.hpp:38
__host__ __device__ BinaryWithUnaryCombinedOp()
Definition: combined_element_wise_operation.hpp:36
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: combined_element_wise_operation.hpp:46
Definition: combined_element_wise_operation.hpp:68
__host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0, BinaryOp0 binary_op1, UnaryOp0 unary_op0, UnaryOp1 unary_op1, UnaryOp2 unary_op2)
Definition: combined_element_wise_operation.hpp:74
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1, const X2 &x2) const
Definition: combined_element_wise_operation.hpp:88
__host__ __device__ TrinaryWithUnaryCombinedOp()
Definition: combined_element_wise_operation.hpp:69
Definition: combined_element_wise_operation.hpp:15
__host__ __device__ UnaryCombinedOp()
Definition: combined_element_wise_operation.hpp:16
Tuple< UnaryOpsSet... > unary_ops_
Definition: combined_element_wise_operation.hpp:27
__host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops)
Definition: combined_element_wise_operation.hpp:18
__host__ __device__ void operator()(Y &y, const X &x) const
Definition: combined_element_wise_operation.hpp:21