9 namespace element_wise {
13 template <
typename Y,
typename X0,
typename X1>
14 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
17 __host__ __device__ constexpr
void
18 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
24 __host__ __device__ constexpr
void
25 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
31 __host__ __device__ constexpr
void
32 operator()<
float>(
float& y,
const float& x0,
const half_t& x1)
const
34 y = x0 + type_convert<half_t>(x1);
38 __host__ __device__ constexpr
void
39 operator()<
half_t>(
half_t& y,
const float& x0,
const float& x1)
const
41 y = type_convert<half_t>(x0 + x1);
45 __host__ __device__ constexpr
void
48 y = type_convert<half_t>(x0) + x1;
52 __host__ __device__ constexpr
void
59 __host__ __device__ constexpr
void
60 operator()<
float>(
float& y,
const float& x0,
const bf16_t& x1)
const
62 const float x1_tmp = type_convert<float>(x1);
67 __host__ __device__ constexpr
void
70 const float x1_tmp = type_convert<float>(x0);
71 const float x2_tmp = type_convert<float>(x1);
72 const float y_tmp = x1_tmp + x2_tmp;
73 y = type_convert<bf16_t>(y_tmp);
77 __host__ __device__ constexpr
void
80 const float x2_tmp = type_convert<float>(x1);
81 const float y_tmp = x0 + x2_tmp;
82 y = type_convert<bf16_t>(y_tmp);
86 __host__ __device__ constexpr
void
87 operator()<
bf16_t>(
bf16_t& y,
const float& x0,
const float& x1)
const
89 const float y_tmp = x0 + x1;
90 y = type_convert<bf16_t>(y_tmp);
94 __host__ __device__ constexpr
void
Definition: cluster_descriptor.hpp:13
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
_Float16 half_t
Definition: half.hpp:111
Definition: binary_elementwise_operation.hpp:12
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const