/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp Source File
combined_element_wise_operation.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 namespace tensor_operation {
10 namespace element_wise {
11 
12 // y = UnaryOp0(UnaryOp1(...(x)))
13 template <typename... UnaryOpsSet>
15 {
16  static constexpr const char* name = "UnaryCombinedOp";
17 
18  __host__ __device__ UnaryCombinedOp() : unary_ops_() {}
19 
20  __host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}
21 
22  template <typename Y, typename X>
23  __host__ __device__ void operator()(Y& y, const X& x) const
24  {
25  // Execute first unary op to copy data to y
26  unary_ops_.At(Number<0>{})(y, x);
27 
28  static_for<1, Tuple<UnaryOpsSet...>::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); });
29  };
30 
31  Tuple<UnaryOpsSet...> unary_ops_;
32 };
33 
// Applies a unary operation to each of two inputs and combines the results with a
// binary operation:
//   y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1))
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
struct BinaryWithUnaryCombinedOp
{
    static constexpr const char* name = "BinaryWithUnaryCombinedOp";

    // Default-constructs all three operations.
    __host__ __device__ BinaryWithUnaryCombinedOp() : binary_op_(), unary_op0_(), unary_op1_() {}

    // Stores copies of the given operations.
    __host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
                                                  UnaryOp0 unary_op0,
                                                  UnaryOp1 unary_op1)
        : binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1)
    {
    }

    // Computes y from x0 and x1. Both intermediate results are held in temporaries of
    // the output type Y, so each unary operation must be callable as op(Y&, const Xi&).
    template <typename Y, typename X0, typename X1>
    __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        // The temporaries are left default-initialized; each one is written by its
        // unary operation before the binary operation reads it.
        Y unary_x0_tmp_result;
        Y unary_x1_tmp_result;
        unary_op0_(unary_x0_tmp_result, x0);
        unary_op1_(unary_x1_tmp_result, x1);
        binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result);
    }

    private:
    BinaryOp binary_op_;
    UnaryOp0 unary_op0_;
    UnaryOp1 unary_op1_;
};
64 
// Applies a unary operation to each of three inputs and folds the results with two
// binary operations. BinaryOp0 combines the first two results, BinaryOp1 then combines
// that intermediate value with the third:
//   y = BinaryOp1(BinaryOp0(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2))
template <typename BinaryOp0,
          typename BinaryOp1,
          typename UnaryOp0,
          typename UnaryOp1,
          typename UnaryOp2>
struct TrinaryWithUnaryCombinedOp
{
    static constexpr const char* name = "TrinaryWithUnaryCombinedOp";

    // Default-constructs all five operations.
    __host__ __device__ TrinaryWithUnaryCombinedOp()
        : binary_op0_(), binary_op1_(), unary_op0_(), unary_op1_(), unary_op2_()
    {
    }

    // Stores copies of the given operations. binary_op1 is taken as BinaryOp1 so that
    // the two binary operations may be of different types.
    __host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0,
                                                   BinaryOp1 binary_op1,
                                                   UnaryOp0 unary_op0,
                                                   UnaryOp1 unary_op1,
                                                   UnaryOp2 unary_op2)
        : binary_op0_(binary_op0),
          binary_op1_(binary_op1),
          unary_op0_(unary_op0),
          unary_op1_(unary_op1),
          unary_op2_(unary_op2)
    {
    }

    // Computes y from x0, x1 and x2. All intermediate results are held in temporaries
    // of the output type Y, so each unary operation must be callable as
    // op(Y&, const Xi&).
    template <typename Y, typename X0, typename X1, typename X2>
    __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const
    {
        // The temporaries are left default-initialized; each one is written by its
        // unary operation before any binary operation reads it.
        Y unary_x0_tmp_result;
        Y unary_x1_tmp_result;
        Y unary_x2_tmp_result;
        unary_op0_(unary_x0_tmp_result, x0);
        unary_op1_(unary_x1_tmp_result, x1);
        unary_op2_(unary_x2_tmp_result, x2);

        // binary_op0_ writes its result back into the x0 temporary, which is also its
        // first input; that temporary then carries the intermediate value forward.
        binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result);
        binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result);
    }

    private:
    BinaryOp0 binary_op0_{};
    BinaryOp1 binary_op1_{};
    UnaryOp0 unary_op0_{};
    UnaryOp1 unary_op1_{};
    UnaryOp2 unary_op2_{};
};
114 
117 
118 } // namespace element_wise
119 } // namespace tensor_operation
120 } // namespace ck
Definition: ck.hpp:268
Definition: tuple.hpp:117
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: combined_element_wise_operation.hpp:37
__host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op, UnaryOp0 unary_op0, UnaryOp1 unary_op1)
Definition: combined_element_wise_operation.hpp:42
__host__ __device__ BinaryWithUnaryCombinedOp()
Definition: combined_element_wise_operation.hpp:40
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: combined_element_wise_operation.hpp:50
static constexpr const char * name
Definition: combined_element_wise_operation.hpp:38
Definition: combined_element_wise_operation.hpp:72
__host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0, BinaryOp0 binary_op1, UnaryOp0 unary_op0, UnaryOp1 unary_op1, UnaryOp2 unary_op2)
Definition: combined_element_wise_operation.hpp:80
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1, const X2 &x2) const
Definition: combined_element_wise_operation.hpp:94
static constexpr const char * name
Definition: combined_element_wise_operation.hpp:73
__host__ __device__ TrinaryWithUnaryCombinedOp()
Definition: combined_element_wise_operation.hpp:75
Definition: combined_element_wise_operation.hpp:15
__host__ __device__ UnaryCombinedOp()
Definition: combined_element_wise_operation.hpp:18
Tuple< UnaryOpsSet... > unary_ops_
Definition: combined_element_wise_operation.hpp:29
__host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops)
Definition: combined_element_wise_operation.hpp:20
__host__ __device__ void operator()(Y &y, const X &x) const
Definition: combined_element_wise_operation.hpp:23
static constexpr const char * name
Definition: combined_element_wise_operation.hpp:16