13 namespace tensor_operation {
14 namespace element_wise {
36 template <
typename Y,
typename X0,
typename X1,
typename X2>
37 __host__ __device__ constexpr
void operator()(Y&,
const X0&,
const X1&,
const X2&)
const;
49 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& y,
52 const float& x2)
const
55 float b = a > 0 ? a : 0;
65 float b = a > 0 ? a : 0;
75 float b = a > 0 ? a : 0;
90 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
105 template <
typename Y,
typename X0,
typename X1,
typename X2>
106 __host__ __device__ constexpr
void operator()(Y&,
const X0&,
const X1&,
const X2&)
const;
109 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& y,
112 const float& x2)
const
115 float b = a +
float{3};
116 float c = (b > 0) * (b >
float{6} ?
float{6} : b) * a *
float{0.166667};
126 float b = a +
float{3};
127 float c = (b > 0) * (b >
float{6} ?
float{6} : b) * a *
float{0.166667};
137 template <
typename E,
typename C,
typename D0,
typename D1>
138 __host__ __device__
void operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const
143 "Data type is not supported by this operation!");
147 "Data type is not supported by this operation!");
151 "Data type is not supported by this operation!");
155 "Data type is not supported by this operation!");
157 const C y = c + type_convert<C>(d0) + type_convert<C>(d1);
158 e = type_convert<E>(y);
166 template <
typename E,
typename C,
typename D0,
typename D1>
167 __host__ __device__
void operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
175 const half_t y = (c + d0) * d1;
184 const half_t y = (type_convert<half_t>(c) + d0) * d1;
188 __host__ __device__
void operator()<float, float,
half_t,
half_t>(
float& e,
193 const float y = (c + d0) * d1;
202 template <
typename E,
typename C,
typename D0,
typename D1>
203 __host__ __device__
void operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
211 const half_t y = (c * d0) + d1;
220 const half_t y = type_convert<half_t>(c) * d0 + d1;
229 const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
233 __host__ __device__
void operator()<float, float,
half_t,
half_t>(
float& e,
238 const float y = c * d0 + d1;
242 __host__ __device__
void operator()<
half_t, float, float,
float>(
half_t& e,
245 const float& d1)
const
247 const float y = c * d0 + d1;
254 template <
typename E,
typename C,
typename D0,
typename D1>
255 __host__ __device__ constexpr
void
256 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
259 __host__ __device__ constexpr
void operator()<
ck::half_t, float, float,
float>(
260 ck::half_t& e,
const float& c,
const float& d0,
const float& d1)
const
262 const float x0_f = c * d0 * d1;
264 e = ck::type_convert<ck::half_t>(x0_f);
268 __host__ __device__ constexpr
void operator()<
ck::bhalf_t, float, float,
float>(
269 ck::bhalf_t& e,
const float& c,
const float& d0,
const float& d1)
const
271 const float x0_f = c * d0 * d1;
273 e = ck::type_convert<ck::bhalf_t>(x0_f);
281 ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
283 e = ck::type_convert<ck::half_t>(x0_f);
287 __host__ __device__ constexpr
void operator()<
ck::half_t, int, float,
float>(
288 ck::half_t& e,
const int& c,
const float& d0,
const float& d1)
const
291 ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
293 e = ck::type_convert<ck::half_t>(x0_f);
297 __host__ __device__ constexpr
void operator()<
ck::bhalf_t, int, float,
float>(
298 ck::bhalf_t& e,
const int& c,
const float& d0,
const float& d1)
const
301 ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
303 e = ck::type_convert<ck::bhalf_t>(x0_f);
309 template <
typename E,
typename C,
typename D0,
typename D1>
310 __host__ __device__ constexpr
void
311 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
317 const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
321 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
323 e = ck::type_convert<ck::bhalf_t>(x1_f);
330 template <
typename E,
typename C,
typename D0,
typename D1>
331 __host__ __device__ constexpr
void
332 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
335 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& e,
338 const float& d1)
const
340 const float x = c + d0 + d1;
342 FastGelu{}.template operator()<float,
float>(e, x);
349 const half_t x = c + d0 + d1;
358 const float x0_f = c + d0 + d1;
365 e = type_convert<half_t>(x1_f);
372 const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
379 e = type_convert<bhalf_t>(x1_f);
387 type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
394 e = type_convert<int8_t>(x1_f);
407 template <
typename E,
typename C,
typename D0,
typename D1>
408 __host__ __device__ constexpr
void
409 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
412 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& e,
415 const float& d1)
const
425 const float x = type_convert<float>(c) *
alpha1_ +
alpha2_ * type_convert<float>(d0) +
426 type_convert<float>(d1);
429 result = x > 0 ? x : 0;
431 e = type_convert<half_t>(result);
438 const float x = type_convert<float>(c) *
alpha1_ +
alpha2_ * type_convert<float>(d0) +
439 type_convert<float>(d1);
442 result = x > 0 ? x : 0;
444 e = type_convert<bhalf_t>(result);
448 __host__ __device__ constexpr
void operator()<
int8_t,
int8_t, float,
float>(
449 int8_t& e,
const int8_t& c,
const float& d0,
const float& d1)
const
451 const float x = type_convert<float>(c) *
alpha1_ +
alpha2_ * d0 + d1;
454 result = x > 0 ? x : 0;
456 e = type_convert<int8_t>(result);
468 template <
typename T1,
typename T2,
typename T3>
472 const T2& mean_square,
474 const T3& beta)
const;
480 const float& mean_square,
484 using ck::math::sqrt;
486 float variance = mean_square - (mean * mean);
488 float tmp_x = type_convert<float>(x);
489 float tmp_gamma = type_convert<float>(gamma);
490 float tmp_beta = type_convert<float>(beta);
493 ((tmp_x - mean) / sqrt(variance + type_convert<float>(
epsilon_))) * tmp_gamma +
496 y = type_convert<half_t>(tmp_y);
500 __host__ __device__ constexpr
void operator()<float, float,
float>(
float& y,
503 const float& mean_square,
505 const float& beta)
const
507 using ck::math::sqrt;
509 float variance = mean_square - (mean * mean);
510 y = ((x - mean) / sqrt(variance + type_convert<float>(
epsilon_))) * gamma + beta;
514 __host__ __device__ constexpr
void operator()<double, double,
double>(
double& y,
517 const double& mean_square,
519 const double& beta)
const
521 using ck::math::sqrt;
523 double variance = mean_square - (mean * mean);
524 y = ((x - mean) / sqrt(variance +
epsilon_)) * gamma + beta;
538 template <
typename T1,
typename T2,
typename T3,
typename T4>
544 const T4& beta)
const
547 "Data type is not supported by this operation!");
550 using ck::math::sqrt;
554 tmp_x = type_convert<T2>(x);
556 tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(
epsilon_))) *
557 type_convert<T2>(gamma) +
558 type_convert<T2>(beta);
559 y = type_convert<T1>(tmp_y);
570 float epsilon = 1e-4)
575 template <
typename T>
585 using ck::math::sqrt;
587 float tmp_x = type_convert<float>(x) + type_convert<float>(bias);
590 ((tmp_x - type_convert<float>(mean)) / sqrt(type_convert<float>(variance) +
epsilon_)) *
591 type_convert<float>(gamma) +
592 type_convert<float>(beta);
594 y = type_convert<T>(tmp_y);
602 const float& variance,
604 const float& beta)
const
607 using ck::math::sqrt;
609 float tmp_y = (((x + bias) - mean) / sqrt(variance +
epsilon_)) * gamma + beta;
617 template <
typename Y,
typename X>
625 y = ck::type_convert<float, ck::bhalf_t>(x);
634 y = ck::type_convert<ck::bhalf_t, float>(x);
__host__ T ceil(T x)
Definition: math_v2.hpp:331
__host__ T floor(T x)
Definition: math_v2.hpp:367
int8_t int8_t
Definition: int8.hpp:20
int32_t int32_t
Definition: integer.hpp:10
_Float16 half_t
Definition: data_type.hpp:30
ushort bhalf_t
Definition: data_type.hpp:29
__host__ constexpr __device__ Y type_convert(X x)
Definition: type_convert.hpp:98
_BitInt(4) int4_t
Definition: data_type.hpp:31
Definition: numeric_limits.hpp:309
Definition: element_wise_operation.hpp:329
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:136
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:138
Definition: element_wise_operation.hpp:104
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: element_wise_operation.hpp:165
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:35
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: element_wise_operation.hpp:567
BiasNormalizeInInferClamp(float floor=0.f, float ceil=NumericLimits< float >::Max(), float epsilon=1e-4)
Definition: element_wise_operation.hpp:568
__host__ constexpr __device__ void operator()(T &y, const T &x, const T &bias, const T &mean, const T &variance, const T &gamma, const T &beta) const
Definition: element_wise_operation.hpp:576
float epsilon_
Definition: element_wise_operation.hpp:614
Clamp clamp_
Definition: element_wise_operation.hpp:611
__host__ constexpr __device__ void operator()(float &y, const float &x, const float &bias, const float &mean, const float &variance, const float &gamma, const float &beta) const
Definition: element_wise_operation.hpp:598
Definition: unary_element_wise_operation.hpp:757
Definition: unary_element_wise_operation.hpp:866
Definition: element_wise_operation.hpp:308
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:201
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:253
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:464
Normalize(double epsilon=1e-4)
Definition: element_wise_operation.hpp:466
double epsilon_
Definition: element_wise_operation.hpp:525
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &mean_square, const T3 &gamma, const T3 &beta) const
Definition: element_wise_operation.hpp:535
double epsilon_
Definition: element_wise_operation.hpp:560
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &variance, const T3 &gamma, const T4 &beta) const
Definition: element_wise_operation.hpp:539
NormalizeInInfer(double epsilon=1e-4)
Definition: element_wise_operation.hpp:536
Definition: element_wise_operation.hpp:400
ScaleAddScaleAddRelu(const float alpha1=1.f, const float alpha2=1.f)
Definition: element_wise_operation.hpp:402
const float alpha2_
Definition: element_wise_operation.hpp:460
const float alpha1_
Definition: element_wise_operation.hpp:459
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
__host__ __device__ void operator()(ck::bhalf_t &y, float &x) const
Definition: element_wise_operation.hpp:632
__host__ __device__ void operator()(float &y, ck::bhalf_t &x) const
Definition: element_wise_operation.hpp:623
Definition: element_wise_operation.hpp:618