10 namespace tensor_operation {
 
   11 namespace element_wise {
 
   15     static constexpr 
const char* 
name = 
"Add";
 
   17     template <
typename Y, 
typename X0, 
typename X1>
 
   18     __host__ __device__ constexpr 
void operator()(Y& y, 
const X0& x0, 
const X1& x1) 
const;
 
   21     __host__ __device__ constexpr 
void 
   22     operator()<
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
   28     __host__ __device__ constexpr 
void 
   29     operator()<
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
   35     __host__ __device__ constexpr 
void 
   36     operator()<
float>(
float& y, 
const float& x0, 
const half_t& x1) 
const 
   38         y = x0 + type_convert<half_t>(x1);
 
   42     __host__ __device__ constexpr 
void 
   43     operator()<
half_t>(
half_t& y, 
const float& x0, 
const float& x1) 
const 
   45         y = type_convert<half_t>(x0 + x1);
 
   49     __host__ __device__ constexpr 
void 
   52         y = x0 + type_convert<float>(x1);
 
   56     __host__ __device__ constexpr 
void 
   63     __host__ __device__ constexpr 
void 
   64     operator()<
float>(
float& y, 
const float& x0, 
const bhalf_t& x1) 
const 
   66         const float x1_tmp = ck::type_convert<float>(x1);
 
   71     __host__ __device__ constexpr 
void 
   74         const float x1_tmp = ck::type_convert<float>(x0);
 
   75         const float x2_tmp = ck::type_convert<float>(x1);
 
   76         const float y_tmp  = x1_tmp + x2_tmp;
 
   77         y                  = ck::type_convert<bhalf_t>(y_tmp);
 
   81     __host__ __device__ constexpr 
void 
   84         const float x2_tmp = ck::type_convert<float>(x1);
 
   85         const float y_tmp  = x0 + x2_tmp;
 
   86         y                  = ck::type_convert<bhalf_t>(y_tmp);
 
   90     __host__ __device__ constexpr 
void 
   99     static constexpr 
const char* 
name = 
"Max";
 
  101     template <
typename Y, 
typename X0, 
typename X1>
 
  102     __host__ __device__ 
void operator()(Y& y, 
const X0& x0, 
const X1& x1)
 const 
  104         const Y x0_converted = type_convert<Y>(x0);
 
  105         const Y x1_converted = type_convert<Y>(x1);
 
  112     static constexpr 
const char* 
name = 
"Min";
 
  114     template <
typename Y, 
typename X0, 
typename X1>
 
  115     __host__ __device__ 
void operator()(Y& y, 
const X0& x0, 
const X1& x1)
 const 
  117         const Y x0_converted = type_convert<Y>(x0);
 
  118         const Y x1_converted = type_convert<Y>(x1);
 
  125     static constexpr 
const char* 
name = 
"Multiply";
 
  127     template <
typename Y, 
typename X0, 
typename X1>
 
  128     __host__ __device__ constexpr 
void operator()(Y& y, 
const X0& x0, 
const X1& x1) 
const;
 
  131     __host__ __device__ constexpr 
void 
  132     operator()<
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
  138     __host__ __device__ constexpr 
void 
  139     operator()<
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
  145     __host__ __device__ constexpr 
void 
  146     operator()<
float>(
float& y, 
const float& x0, 
const half_t& x1) 
const 
  148         y = x0 * type_convert<half_t>(x1);
 
  152     __host__ __device__ constexpr 
void 
  155         y = type_convert<half_t>(x0 * x1);
 
  159     __host__ __device__ constexpr 
void 
  162         y = type_convert<half_t>(x0) * x1;
 
  166     __host__ __device__ constexpr 
void 
  173     __host__ __device__ constexpr 
void 
  174     operator()<
float>(
float& y, 
const float& x0, 
const bhalf_t& x1) 
const 
  176         const float x1_tmp = ck::type_convert<float>(x1);
 
  181     __host__ __device__ constexpr 
void 
  184         const float x1_tmp = ck::type_convert<float>(x0);
 
  185         const float x2_tmp = ck::type_convert<float>(x1);
 
  186         const float y_tmp  = x1_tmp * x2_tmp;
 
  187         y                  = ck::type_convert<bhalf_t>(y_tmp);
 
  191     __host__ __device__ constexpr 
void 
  194         const float x1_tmp = ck::type_convert<float>(x0);
 
  195         const float x2_tmp = ck::type_convert<float>(x1);
 
  196         const float y_tmp  = x1_tmp * x2_tmp;
 
  197         y                  = ck::type_convert<bhalf_t>(y_tmp);
 
  201     __host__ __device__ constexpr 
void 
  204         const float x2_tmp = ck::type_convert<float>(x1);
 
  205         const float y_tmp  = x0 * x2_tmp;
 
  206         y                  = ck::type_convert<bhalf_t>(y_tmp);
 
  210     __host__ __device__ constexpr 
void 
  219     static constexpr 
const char* 
name = 
"ScaleAdd";
 
  223     template <
typename Y, 
typename X0, 
typename X1>
 
  224     __host__ __device__ constexpr 
void operator()(Y& y, 
const X0& x0, 
const X1& x1)
 const 
  226         y = ck::type_convert<Y>(
scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
 
  230     __host__ __device__ 
void 
  231     operator()<float, float, 
half_t>(
float& y, 
const float& x0, 
const half_t& x1) 
const 
  233         y = 
scale_ * x0 + ck::type_convert<float>(x1);
 
  237     __host__ __device__ 
void 
  238     operator()<float, float, 
bhalf_t>(
float& y, 
const float& x0, 
const bhalf_t& x1) 
const 
  240         y = 
scale_ * x0 + ck::type_convert<float>(x1);
 
  248     static constexpr 
const char* 
name = 
"Subtract";
 
  250     template <
typename T>
 
  251     __host__ __device__ constexpr 
void operator()(T& y, 
const T& x0, 
const T& x1) 
const;
 
  254     __host__ __device__ constexpr 
void 
  255     operator()<
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
  261     __host__ __device__ constexpr 
void 
  262     operator()<
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
  268     __host__ __device__ constexpr 
void 
  275     __host__ __device__ constexpr 
void 
  278         const float x1_tmp = ck::type_convert<float>(x0);
 
  279         const float x2_tmp = ck::type_convert<float>(x1);
 
  280         const float y_tmp  = x1_tmp - x2_tmp;
 
  281         y                  = ck::type_convert<bhalf_t>(y_tmp);
 
  285     __host__ __device__ constexpr 
void 
  294     static constexpr 
const char* 
name = 
"Bilinear";
 
  298     template <
typename Y, 
typename X0, 
typename X1>
 
  299     __host__ __device__ constexpr 
void operator()(Y&, 
const X0&, 
const X1&) 
const;
 
  302     __host__ __device__ constexpr 
void 
  303     operator()<double, double, 
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
  309     __host__ __device__ constexpr 
void 
  310     operator()<float, float, 
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
  316     __host__ __device__ constexpr 
void 
  319         y = type_convert<int8_t>(
alpha_ * type_convert<float>(x0) +
 
  320                                  beta_ * type_convert<float>(x1));
 
  324     __host__ __device__ constexpr 
void 
  327         y = type_convert<half_t>(
alpha_) * x0 + type_convert<half_t>(
beta_) * x1;
 
  331     __host__ __device__ constexpr 
void 
  334         y = type_convert<half_t>(
alpha_ * x0 + 
beta_ * ck::type_convert<float>(x1));
 
  338     __host__ __device__ constexpr 
void 
  341         const float x0_tmp = type_convert<float>(x0);
 
  342         const float x1_tmp = type_convert<float>(x1);
 
  343         const float y_tmp  = 
alpha_ * x0_tmp + 
beta_ * x1_tmp;
 
  344         y                  = type_convert<bhalf_t>(y_tmp);
 
  348     __host__ __device__ constexpr 
void 
  351         const float x1_tmp = ck::type_convert<float>(x1);
 
  357     __host__ __device__ constexpr 
void 
  360         y = type_convert<int8_t>(
alpha_ * type_convert<float>(x0) +
 
  361                                  beta_ * type_convert<float>(x1));
 
  370     static constexpr 
const char* 
name = 
"AddClamp";
 
  375     template <
typename Y, 
typename X0, 
typename X1>
 
  376     __host__ __device__ constexpr 
void operator()(Y& y, 
const X0& x0, 
const X1& x1) 
const;
 
  379     __host__ __device__ constexpr 
void 
  380     operator()<float, float, 
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
  382         const float a = x0 + x1;
 
  387     __host__ __device__ constexpr 
void 
  388     operator()<double, double, 
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
  390         const double a = x0 + x1;
 
  395     __host__ __device__ constexpr 
void 
  405     __host__ __device__ constexpr 
void 
  408         const float a = x0 + type_convert<float>(x1);
 
  410         y             = type_convert<half_t>(b);
 
  414     __host__ __device__ constexpr 
void 
  415     operator()<float, float, 
half_t>(
float& y, 
const float& x0, 
const half_t& x1) 
const 
  417         const float a = x0 + type_convert<float>(x1);
 
  422     __host__ __device__ constexpr 
void 
  425         const float a = x0 + type_convert<float>(x1);
 
  427         y             = type_convert<bhalf_t>(b);
 
  431     __host__ __device__ constexpr 
void 
  434         const float a = type_convert<float>(x0) + type_convert<float>(x1);
 
  436         y             = type_convert<bhalf_t>(b);
 
  440     __host__ __device__ constexpr 
void 
  441     operator()<int, int, 
int8_t>(
int& y, 
const int& x0, 
const int8_t& x1) 
const 
  448     __host__ __device__ constexpr 
void 
  461     static constexpr 
const char* 
name = 
"AddRelu";
 
  463     template <
typename Y, 
typename X0, 
typename X1>
 
  464     __host__ __device__ constexpr 
void operator()(Y& y, 
const X0& x0, 
const X1& x1) 
const;
 
  467     __host__ __device__ constexpr 
void 
  468     operator()<float, float, 
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
  470         const float a = x0 + x1;
 
  471         y             = 
a > 0.0f ? 
a : 0.0f;
 
  475     __host__ __device__ constexpr 
void 
  476     operator()<double, double, 
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
  478         const double a = x0 + x1;
 
  479         y              = 
a > 0.0 ? 
a : 0.0;
 
  483     __host__ __device__ constexpr 
void 
  487         y              = 
a > type_convert<half_t>(0.0f) ? 
a : type_convert<half_t>(0.0f);
 
  491     __host__ __device__ constexpr 
void 
  494         const float a = x0 + type_convert<float>(x1);
 
  495         const float b = 
a > 0.0f ? 
a : 0.0f;
 
  496         y             = type_convert<half_t>(b);
 
  500     __host__ __device__ constexpr 
void 
  501     operator()<float, float, 
half_t>(
float& y, 
const float& x0, 
const half_t& x1) 
const 
  503         const float a = x0 + type_convert<float>(x1);
 
  504         y             = 
a > 0.0f ? 
a : 0.0f;
 
  508     __host__ __device__ constexpr 
void 
  511         const float a = x0 + type_convert<float>(x1);
 
  512         const float b = 
a > 0.0f ? 
a : 0.0f;
 
  513         y             = type_convert<bhalf_t>(b);
 
  517     __host__ __device__ constexpr 
void 
  520         const float a = type_convert<float>(x0) + type_convert<float>(x1);
 
  521         const float b = 
a > 0.0f ? 
a : 0.0f;
 
  522         y             = type_convert<bhalf_t>(b);
 
  526     __host__ __device__ constexpr 
void 
  527     operator()<int, int, 
int8_t>(
int& y, 
const int& x0, 
const int8_t& x1) 
const 
  534     __host__ __device__ constexpr 
void 
  544     static constexpr 
const char* 
name = 
"AddHardswish";
 
  546     template <
typename T>
 
  547     __host__ __device__ constexpr 
void operator()(T& y, 
const T& x0, 
const T& x1) 
const;
 
  550     __host__ __device__ constexpr 
void 
  551     operator()<
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
  554         float b = 
a + 
float{3};
 
  555         float c = (b > 0) * (b > 6.0f ? 6.0f : b) * 
a * 0.166667f;
 
  560     __host__ __device__ constexpr 
void 
  561     operator()<
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
  565         double c = (b > 0) * (b > 6.0 ? 6.0 : b) * 
a * 0.166667;
 
  570     __host__ __device__ constexpr 
void 
  575         float c = (b > 0) * (b > 6.0f ? 6.0f : b) * 
a * 0.166667f;
 
  583     static constexpr 
const char* 
name = 
"AddFastGelu";
 
  585     template <
typename E, 
typename C, 
typename D>
 
  586     __host__ __device__ constexpr 
void operator()(E& e, 
const C& c, 
const D& d) 
const;
 
  589     __host__ __device__ constexpr 
void 
  590     operator()<float, float, 
float>(
float& e, 
const float& c, 
const float& d) 
const 
  592         const float x = c + d;
 
  594         FastGelu{}.template operator()<float, 
float>(e, x);
 
  598     __host__ __device__ constexpr 
void 
  607     __host__ __device__ constexpr 
void 
  610         const float x0_f = c + d;
 
  617         e = type_convert<half_t>(x1_f);
 
  621     __host__ __device__ constexpr 
void 
  624         const float x0_f = type_convert<float>(c) + type_convert<float>(d);
 
  628         FastGelu{}.template operator()<float, 
float>(x1_f, x0_f);
 
  630         e = type_convert<bhalf_t>(x1_f);
 
  634     __host__ __device__ constexpr 
void 
  637         const float x0_f = c + type_convert<float>(d);
 
  641         FastGelu{}.template operator()<float, 
float>(x1_f, x0_f);
 
  643         e = type_convert<bhalf_t>(x1_f);
 
  650     static constexpr 
const char* 
name = 
"MultiplyFastGelu";
 
  652     template <
typename E, 
typename C, 
typename D>
 
  653     __host__ __device__ constexpr 
void operator()(E& e, 
const C& c, 
const D& d) 
const;
 
  656     __host__ __device__ constexpr 
void 
  657     operator()<float, float, 
float>(
float& e, 
const float& c, 
const float& d) 
const 
  659         const float x = c * d;
 
  661         FastGelu{}.template operator()<float, 
float>(e, x);
 
  665     __host__ __device__ constexpr 
void 
  674     __host__ __device__ constexpr 
void 
  677         const float x0_f = c * d;
 
  684         e = type_convert<half_t>(x1_f);
 
  688     __host__ __device__ constexpr 
void 
  691         const float x0_f = type_convert<float>(c) * type_convert<float>(d);
 
  695         FastGelu{}.template operator()<float, 
float>(x1_f, x0_f);
 
  697         e = type_convert<bhalf_t>(x1_f);
 
  701     __host__ __device__ constexpr 
void 
  704         const float x0_f = c * type_convert<float>(d);
 
  708         FastGelu{}.template operator()<float, 
float>(x1_f, x0_f);
 
  710         e = type_convert<bhalf_t>(x1_f);
 
  717     static constexpr 
const char* 
name = 
"AddSilu";
 
  719     template <
typename E, 
typename C, 
typename D>
 
  720     __host__ __device__ constexpr 
void operator()(E& e, 
const C& c, 
const D& d) 
const;
 
  723     __host__ __device__ constexpr 
void 
  724     operator()<float, float, 
float>(
float& e, 
const float& c, 
const float& d) 
const 
  726         const float x = c + d;
 
  728         Silu{}.template operator()<
float>(e, x);
 
  732     __host__ __device__ constexpr 
void 
  741     __host__ __device__ constexpr 
void 
  744         const float x0_f = c + d;
 
  748         Silu{}.template operator()<
float>(x1_f, x0_f);
 
  750         e = type_convert<half_t>(x1_f);
 
  754     __host__ __device__ constexpr 
void 
  757         const float x0_f = c + type_convert<float>(d);
 
  761         Silu{}.template operator()<
float>(x1_f, x0_f);
 
  763         e = type_convert<bhalf_t>(x1_f);
 
  769     static constexpr 
const char* 
name = 
"ConvScaleAdd";
 
  772                                      float scale_wei = 1.f,
 
  773                                      float scale_out = 1.f)
 
  778     template <
typename E, 
typename C, 
typename D>
 
  779     __host__ __device__ 
void operator()(E& e, 
const C& c, 
const D& d) 
const;
 
  782     __host__ __device__ 
void 
  783     operator()<
f8_t, float, 
float>(
f8_t& e, 
const float& c, 
const float& d) 
const 
__host__ T ceil(T x)
Definition: math_v2.hpp:331
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
__host__ T floor(T x)
Definition: math_v2.hpp:367
 
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
 
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1762
 
_Float16 half_t
Definition: data_type.hpp:31
 
ushort bhalf_t
Definition: data_type.hpp:30
 
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
 
signed int int32_t
Definition: stdint.h:123
 
signed char int8_t
Definition: stdint.h:121
 
Definition: numeric_limits.hpp:309
 
Definition: amd_ck_fp8.hpp:36
 
Definition: binary_element_wise_operation.hpp:369
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:370
 
AddClamp(float floor=0.f, float ceil=NumericLimits< float >::Max())
Definition: binary_element_wise_operation.hpp:372
 
const float ceil_
Definition: binary_element_wise_operation.hpp:456
 
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
 
const float floor_
Definition: binary_element_wise_operation.hpp:453
 
Definition: binary_element_wise_operation.hpp:582
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:583
 
Definition: binary_element_wise_operation.hpp:543
 
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:544
 
Definition: binary_element_wise_operation.hpp:14
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:15
 
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
 
Definition: binary_element_wise_operation.hpp:460
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:461
 
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
 
Definition: binary_element_wise_operation.hpp:716
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:717
 
Definition: binary_element_wise_operation.hpp:293
 
Bilinear(float alpha=1.f, float beta=1.f)
Definition: binary_element_wise_operation.hpp:296
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:294
 
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &) const
 
float beta_
Definition: binary_element_wise_operation.hpp:365
 
float alpha_
Definition: binary_element_wise_operation.hpp:362
 
Definition: binary_element_wise_operation.hpp:768
 
float scale_in_
Definition: binary_element_wise_operation.hpp:788
 
float scale_wei_
Definition: binary_element_wise_operation.hpp:791
 
__host__ __device__ ConvScaleAdd(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition: binary_element_wise_operation.hpp:771
 
float scale_out_
Definition: binary_element_wise_operation.hpp:792
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:769
 
__host__ __device__ void operator()(E &e, const C &c, const D &d) const
 
Definition: unary_element_wise_operation.hpp:924
 
Definition: binary_element_wise_operation.hpp:98
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:99
 
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:102
 
Definition: binary_element_wise_operation.hpp:111
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:112
 
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:115
 
Definition: binary_element_wise_operation.hpp:649
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:650
 
Definition: binary_element_wise_operation.hpp:124
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:125
 
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
 
Definition: binary_element_wise_operation.hpp:218
 
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:224
 
float scale_
Definition: binary_element_wise_operation.hpp:241
 
__host__ __device__ ScaleAdd(float scale=1.f)
Definition: binary_element_wise_operation.hpp:221
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:219
 
Definition: unary_element_wise_operation.hpp:1087
 
Definition: binary_element_wise_operation.hpp:247
 
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:248
 
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const