13 namespace tensor_operation {
 
   14 namespace element_wise {
 
   36     static constexpr 
const char* 
name = 
"AddReluAdd";
 
   38     template <
typename Y, 
typename X0, 
typename X1, 
typename X2>
 
   39     __host__ __device__ constexpr 
void operator()(Y&, 
const X0&, 
const X1&, 
const X2&) 
const;
 
   51     __host__ __device__ constexpr 
void operator()<float, float, float, 
float>(
float& y,
 
   54                                                                               const float& x2) 
const 
   57         float b = 
a > 0 ? 
a : 0;
 
   63     __host__ __device__ constexpr 
void operator()<float, float, 
half_t, 
half_t>(
 
   64         float& y, 
const float& x0, 
const half_t& x1, 
const half_t& x2) 
const 
   67         float b = 
a > 0 ? 
a : 0;
 
   77         (*this)(y_float, x0, x1, x2);
 
   86         float b = 
a > 0 ? 
a : 0;
 
  101 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 
  116     static constexpr 
const char* 
name = 
"AddHardswishAdd";
 
  118     template <
typename Y, 
typename X0, 
typename X1, 
typename X2>
 
  119     __host__ __device__ constexpr 
void operator()(Y&, 
const X0&, 
const X1&, 
const X2&) 
const;
 
  122     __host__ __device__ constexpr 
void operator()<float, float, float, 
float>(
float& y,
 
  125                                                                               const float& x2) 
const 
  128         float b = 
a + 
float{3};
 
  129         float c = (b > 0) * (b > 
float{6} ? 
float{6} : b) * 
a * 
float{0.166667};
 
  139         float b = 
a + 
float{3};
 
  140         float c = (b > 0) * (b > 
float{6} ? 
float{6} : b) * 
a * 
float{0.166667};
 
  150     static constexpr 
const char* 
name = 
"AddAdd";
 
  152     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  153     __host__ __device__ 
void operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1)
 const 
  158                       "Data type is not supported by this operation!");
 
  162                       "Data type is not supported by this operation!");
 
  166                       "Data type is not supported by this operation!");
 
  170                       "Data type is not supported by this operation!");
 
  172         const C y = c + type_convert<C>(d0) + type_convert<C>(d1);
 
  173         e         = type_convert<E>(y);
 
  181     static constexpr 
const char* 
name = 
"AddMultiply";
 
  183     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  184     __host__ __device__ 
void operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1) 
const;
 
  192         const half_t y = (c + d0) * d1;
 
  201         const half_t y = (type_convert<half_t>(c) + d0) * d1;
 
  205     __host__ __device__ 
void operator()<float, float, 
half_t, 
half_t>(
float& e,
 
  210         const float y = (c + d0) * d1;
 
  219     static constexpr 
const char* 
name = 
"MultiplyAdd";
 
  221     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  222     __host__ __device__ 
void operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1) 
const;
 
  230         const half_t y = (c * d0) + d1;
 
  239         const half_t y = type_convert<half_t>(c) * d0 + d1;
 
  248         const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
 
  252     __host__ __device__ 
void operator()<float, float, 
half_t, 
half_t>(
float& e,
 
  257         const float y = c * d0 + d1;
 
  261     __host__ __device__ 
void operator()<
half_t, float, float, 
float>(
half_t& e,
 
  264                                                                      const float& d1) 
const 
  266         const float y = c * d0 + d1;
 
  273     static constexpr 
const char* 
name = 
"MultiplyMultiply";
 
  275     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  276     __host__ __device__ constexpr 
void 
  277     operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1) 
const;
 
  280     __host__ __device__ constexpr 
void operator()<
ck::half_t, float, float, 
float>(
 
  281         ck::half_t& e, 
const float& c, 
const float& d0, 
const float& d1) 
const 
  283         const float x0_f = c * d0 * d1;
 
  285         e = ck::type_convert<ck::half_t>(x0_f);
 
  289     __host__ __device__ constexpr 
void operator()<
ck::bhalf_t, float, float, 
float>(
 
  290         ck::bhalf_t& e, 
const float& c, 
const float& d0, 
const float& d1) 
const 
  292         const float x0_f = c * d0 * d1;
 
  294         e = ck::type_convert<ck::bhalf_t>(x0_f);
 
  302             ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
 
  304         e = ck::type_convert<ck::half_t>(x0_f);
 
  308     __host__ __device__ constexpr 
void operator()<
ck::half_t, int, float, 
float>(
 
  309         ck::half_t& e, 
const int& c, 
const float& d0, 
const float& d1) 
const 
  312             ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
 
  314         e = ck::type_convert<ck::half_t>(x0_f);
 
  318     __host__ __device__ constexpr 
void operator()<
ck::bhalf_t, int, float, 
float>(
 
  319         ck::bhalf_t& e, 
const int& c, 
const float& d0, 
const float& d1) 
const 
  322             ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
 
  324         e = ck::type_convert<ck::bhalf_t>(x0_f);
 
  330     static constexpr 
const char* 
name = 
"MultiplyAddFastGelu";
 
  332     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  333     __host__ __device__ constexpr 
void 
  334     operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1) 
const;
 
  340         const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
 
  344         FastGelu{}.template operator()<float, 
float>(x1_f, x0_f);
 
  346         e = ck::type_convert<ck::bhalf_t>(x1_f);
 
  353     static constexpr 
const char* 
name = 
"AddAddFastGelu";
 
  355     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  356     __host__ __device__ constexpr 
void 
  357     operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1) 
const;
 
  360     __host__ __device__ constexpr 
void operator()<float, float, float, 
float>(
float& e,
 
  363                                                                               const float& d1) 
const 
  365         const float x = c + d0 + d1;
 
  367         FastGelu{}.template operator()<float, 
float>(e, x);
 
  374         const half_t x = c + d0 + d1;
 
  383         const float x0_f = c + d0 + d1;
 
  390         e = type_convert<half_t>(x1_f);
 
  397         const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
 
  404         e = type_convert<bhalf_t>(x1_f);
 
  412             type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
 
  419         e = type_convert<int8_t>(x1_f);
 
  426     static constexpr 
const char* 
name = 
"ScaleAddScaleAddRelu";
 
  433     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  434     __host__ __device__ constexpr 
void 
  435     operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1) 
const;
 
  438     __host__ __device__ constexpr 
void operator()<float, float, float, 
float>(
float& e,
 
  441                                                                               const float& d1) 
const 
  451         const float x = type_convert<float>(c) * 
alpha1_ + 
alpha2_ * type_convert<float>(d0) +
 
  452                         type_convert<float>(d1);
 
  455         result       = x > 0 ? x : 0;
 
  457         e = type_convert<half_t>(result);
 
  464         const float x = type_convert<float>(c) * 
alpha1_ + 
alpha2_ * type_convert<float>(d0) +
 
  465                         type_convert<float>(d1);
 
  468         result       = x > 0 ? x : 0;
 
  470         e = type_convert<bhalf_t>(result);
 
  474     __host__ __device__ constexpr 
void operator()<
int8_t, 
int8_t, float, 
float>(
 
  475         int8_t& e, 
const int8_t& c, 
const float& d0, 
const float& d1) 
const 
  477         const float x = type_convert<float>(c) * 
alpha1_ + 
alpha2_ * d0 + d1;
 
  480         result       = x > 0 ? x : 0;
 
  482         e = type_convert<int8_t>(result);
 
  491     static constexpr 
const char* 
name = 
"Normalize";
 
  496     template <
typename T1, 
typename T2, 
typename T3>
 
  500                                                   const T2& mean_square,
 
  502                                                   const T3& beta) 
const;
 
  508                                                                          const float& mean_square,
 
  512         using ck::math::sqrt;
 
  514         float variance = mean_square - (mean * mean);
 
  516         float tmp_x     = type_convert<float>(x);
 
  517         float tmp_gamma = type_convert<float>(gamma);
 
  518         float tmp_beta  = type_convert<float>(beta);
 
  521             ((tmp_x - mean) / sqrt(variance + type_convert<float>(
epsilon_))) * tmp_gamma +
 
  524         y = type_convert<half_t>(tmp_y);
 
  528     __host__ __device__ constexpr 
void operator()<float, float, 
float>(
float& y,
 
  531                                                                        const float& mean_square,
 
  533                                                                        const float& beta) 
const 
  535         using ck::math::sqrt;
 
  537         float variance = mean_square - (mean * mean);
 
  538         y = ((x - mean) / sqrt(variance + type_convert<float>(
epsilon_))) * gamma + beta;
 
  542     __host__ __device__ constexpr 
void operator()<double, double, 
double>(
double& y,
 
  545                                                                           const double& mean_square,
 
  547                                                                           const double& beta) 
const 
  549         using ck::math::sqrt;
 
  551         double variance = mean_square - (mean * mean);
 
  552         y               = ((x - mean) / sqrt(variance + 
epsilon_)) * gamma + beta;
 
  564     static constexpr 
const char* 
name = 
"NormalizeInInfer";
 
  568     template <
typename T1, 
typename T2, 
typename T3, 
typename T4>
 
  574                                                   const T4& beta)
 const 
  577                       "Data type is not supported by this operation!");
 
  580         using ck::math::sqrt;
 
  584         tmp_x = type_convert<T2>(x);
 
  586         tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(
epsilon_))) *
 
  587                     type_convert<T2>(gamma) +
 
  588                 type_convert<T2>(beta);
 
  589         y = type_convert<T1>(tmp_y);
 
  598     static constexpr 
const char* 
name = 
"BiasNormalizeInInferClamp";
 
  602                               float epsilon = 1e-4)
 
  607     template <
typename T>
 
  617         using ck::math::sqrt;
 
  619         float tmp_x = type_convert<float>(x) + type_convert<float>(bias);
 
  622             ((tmp_x - type_convert<float>(mean)) / sqrt(type_convert<float>(variance) + 
epsilon_)) *
 
  623                 type_convert<float>(gamma) +
 
  624             type_convert<float>(beta);
 
  626         y = type_convert<T>(tmp_y);
 
  634                                                   const float& variance,
 
  636                                                   const float& beta)
 const 
  639         using ck::math::sqrt;
 
  641         float tmp_y = (((x + bias) - mean) / sqrt(variance + 
epsilon_)) * gamma + beta;
 
  649 template <
typename Y, 
typename X>
 
  655     static constexpr 
const char* name = 
"UnaryTypeConvert";
 
  659         y = ck::type_convert<float, ck::bhalf_t>(x);
 
  666     static constexpr 
const char* name = 
"UnaryTypeConvert";
 
  670         y = ck::type_convert<ck::bhalf_t, float>(x);
 
__host__ T ceil(T x)
Definition: math_v2.hpp:331
 
__host__ T floor(T x)
Definition: math_v2.hpp:367
 
_Float16 half_t
Definition: data_type.hpp:31
 
ushort bhalf_t
Definition: data_type.hpp:30
 
__host__ constexpr __device__ Y type_convert(X x)
Definition: type_convert.hpp:98
 
_BitInt(4) int4_t
Definition: data_type.hpp:32
 
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
 
signed int int32_t
Definition: stdint.h:123
 
signed char int8_t
Definition: stdint.h:121
 
Definition: numeric_limits.hpp:309
 
Definition: element_wise_operation.hpp:352
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
 
static constexpr const char * name
Definition: element_wise_operation.hpp:353
 
Definition: element_wise_operation.hpp:149
 
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:153
 
static constexpr const char * name
Definition: element_wise_operation.hpp:150
 
Definition: element_wise_operation.hpp:115
 
static constexpr const char * name
Definition: element_wise_operation.hpp:116
 
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
 
Definition: element_wise_operation.hpp:180
 
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
 
static constexpr const char * name
Definition: element_wise_operation.hpp:181
 
Definition: element_wise_operation.hpp:35
 
static constexpr const char * name
Definition: element_wise_operation.hpp:36
 
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
 
Definition: element_wise_operation.hpp:597
 
BiasNormalizeInInferClamp(float floor=0.f, float ceil=NumericLimits< float >::Max(), float epsilon=1e-4)
Definition: element_wise_operation.hpp:600
 
__host__ constexpr __device__ void operator()(T &y, const T &x, const T &bias, const T &mean, const T &variance, const T &gamma, const T &beta) const
Definition: element_wise_operation.hpp:608
 
float epsilon_
Definition: element_wise_operation.hpp:646
 
Clamp clamp_
Definition: element_wise_operation.hpp:643
 
__host__ constexpr __device__ void operator()(float &y, const float &x, const float &bias, const float &mean, const float &variance, const float &gamma, const float &beta) const
Definition: element_wise_operation.hpp:630
 
static constexpr const char * name
Definition: element_wise_operation.hpp:598
 
Definition: unary_element_wise_operation.hpp:811
 
Definition: unary_element_wise_operation.hpp:924
 
Definition: element_wise_operation.hpp:329
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
 
static constexpr const char * name
Definition: element_wise_operation.hpp:330
 
Definition: element_wise_operation.hpp:218
 
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
 
static constexpr const char * name
Definition: element_wise_operation.hpp:219
 
Definition: element_wise_operation.hpp:272
 
static constexpr const char * name
Definition: element_wise_operation.hpp:273
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
 
Definition: element_wise_operation.hpp:490
 
Normalize(double epsilon=1e-4)
Definition: element_wise_operation.hpp:494
 
double epsilon_
Definition: element_wise_operation.hpp:553
 
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &mean_square, const T3 &gamma, const T3 &beta) const
 
static constexpr const char * name
Definition: element_wise_operation.hpp:491
 
Definition: element_wise_operation.hpp:563
 
static constexpr const char * name
Definition: element_wise_operation.hpp:564
 
double epsilon_
Definition: element_wise_operation.hpp:590
 
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &variance, const T3 &gamma, const T4 &beta) const
Definition: element_wise_operation.hpp:569
 
NormalizeInInfer(double epsilon=1e-4)
Definition: element_wise_operation.hpp:566
 
Definition: element_wise_operation.hpp:425
 
ScaleAddScaleAddRelu(const float alpha1=1.f, const float alpha2=1.f)
Definition: element_wise_operation.hpp:428
 
static constexpr const char * name
Definition: element_wise_operation.hpp:426
 
const float alpha2_
Definition: element_wise_operation.hpp:486
 
const float alpha1_
Definition: element_wise_operation.hpp:485
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
 
__host__ __device__ void operator()(ck::bhalf_t &y, float &x) const
Definition: element_wise_operation.hpp:668
 
__host__ __device__ void operator()(float &y, ck::bhalf_t &x) const
Definition: element_wise_operation.hpp:657
 
Definition: element_wise_operation.hpp:650