10 namespace tensor_operation {
11 namespace element_wise {
15 template <
typename Y,
typename X0,
typename X1>
16 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
19 __host__ __device__ constexpr
void
20 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
26 __host__ __device__ constexpr
void
27 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
33 __host__ __device__ constexpr
void
34 operator()<
float>(
float& y,
const float& x0,
const half_t& x1)
const
36 y = x0 + type_convert<half_t>(x1);
40 __host__ __device__ constexpr
void
41 operator()<
half_t>(
half_t& y,
const float& x0,
const float& x1)
const
43 y = type_convert<half_t>(x0 + x1);
47 __host__ __device__ constexpr
void
50 y = type_convert<half_t>(x0) + x1;
54 __host__ __device__ constexpr
void
61 __host__ __device__ constexpr
void
62 operator()<
float>(
float& y,
const float& x0,
const bhalf_t& x1)
const
64 const float x1_tmp = ck::type_convert<float>(x1);
69 __host__ __device__ constexpr
void
72 const float x1_tmp = ck::type_convert<float>(x0);
73 const float x2_tmp = ck::type_convert<float>(x1);
74 const float y_tmp = x1_tmp + x2_tmp;
75 y = ck::type_convert<bhalf_t>(y_tmp);
79 __host__ __device__ constexpr
void
82 const float x2_tmp = ck::type_convert<float>(x1);
83 const float y_tmp = x0 + x2_tmp;
84 y = ck::type_convert<bhalf_t>(y_tmp);
88 __host__ __device__ constexpr
void
97 template <
typename Y,
typename X0,
typename X1>
98 __host__ __device__
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
100 const Y x0_converted = type_convert<Y>(x0);
101 const Y x1_converted = type_convert<Y>(x1);
108 template <
typename Y,
typename X0,
typename X1>
109 __host__ __device__
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
111 const Y x0_converted = type_convert<Y>(x0);
112 const Y x1_converted = type_convert<Y>(x1);
119 template <
typename Y,
typename X0,
typename X1>
120 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
123 __host__ __device__ constexpr
void
124 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
130 __host__ __device__ constexpr
void
131 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
137 __host__ __device__ constexpr
void
138 operator()<
float>(
float& y,
const float& x0,
const half_t& x1)
const
140 y = x0 * type_convert<half_t>(x1);
144 __host__ __device__ constexpr
void
147 y = type_convert<half_t>(x0 * x1);
151 __host__ __device__ constexpr
void
154 y = type_convert<half_t>(x0) * x1;
158 __host__ __device__ constexpr
void
165 __host__ __device__ constexpr
void
166 operator()<
float>(
float& y,
const float& x0,
const bhalf_t& x1)
const
168 const float x1_tmp = ck::type_convert<float>(x1);
173 __host__ __device__ constexpr
void
176 const float x1_tmp = ck::type_convert<float>(x0);
177 const float x2_tmp = ck::type_convert<float>(x1);
178 const float y_tmp = x1_tmp * x2_tmp;
179 y = ck::type_convert<bhalf_t>(y_tmp);
183 __host__ __device__ constexpr
void
186 const float x1_tmp = ck::type_convert<float>(x0);
187 const float x2_tmp = ck::type_convert<float>(x1);
188 const float y_tmp = x1_tmp * x2_tmp;
189 y = ck::type_convert<bhalf_t>(y_tmp);
193 __host__ __device__ constexpr
void
196 const float x2_tmp = ck::type_convert<float>(x1);
197 const float y_tmp = x0 * x2_tmp;
198 y = ck::type_convert<bhalf_t>(y_tmp);
202 __host__ __device__ constexpr
void
213 template <
typename Y,
typename X0,
typename X1>
214 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
216 y = ck::type_convert<Y>(
scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
220 __host__ __device__
void
221 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
223 y =
scale_ * x0 + ck::type_convert<float>(x1);
227 __host__ __device__
void
228 operator()<float, float,
bhalf_t>(
float& y,
const float& x0,
const bhalf_t& x1)
const
230 y =
scale_ * x0 + ck::type_convert<float>(x1);
238 template <
typename T>
239 __host__ __device__ constexpr
void operator()(T& y,
const T& x0,
const T& x1)
const;
242 __host__ __device__ constexpr
void
243 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
249 __host__ __device__ constexpr
void
250 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
256 __host__ __device__ constexpr
void
263 __host__ __device__ constexpr
void
266 const float x1_tmp = ck::type_convert<float>(x0);
267 const float x2_tmp = ck::type_convert<float>(x1);
268 const float y_tmp = x1_tmp - x2_tmp;
269 y = ck::type_convert<bhalf_t>(y_tmp);
273 __host__ __device__ constexpr
void
284 template <
typename Y,
typename X0,
typename X1>
285 __host__ __device__ constexpr
void operator()(Y&,
const X0&,
const X1&)
const;
288 __host__ __device__ constexpr
void
289 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
295 __host__ __device__ constexpr
void
296 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
302 __host__ __device__ constexpr
void
305 y = type_convert<int8_t>(
alpha_ * type_convert<float>(x0) +
306 beta_ * type_convert<float>(x1));
310 __host__ __device__ constexpr
void
313 y = type_convert<half_t>(
alpha_) * x0 + type_convert<half_t>(
beta_) * x1;
317 __host__ __device__ constexpr
void
320 y = type_convert<half_t>(
alpha_ * x0 +
beta_ * ck::type_convert<float>(x1));
324 __host__ __device__ constexpr
void
327 const float x0_tmp = type_convert<float>(x0);
328 const float x1_tmp = type_convert<float>(x1);
329 const float y_tmp =
alpha_ * x0_tmp +
beta_ * x1_tmp;
330 y = type_convert<bhalf_t>(y_tmp);
334 __host__ __device__ constexpr
void
337 const float x1_tmp = ck::type_convert<float>(x1);
343 __host__ __device__ constexpr
void
346 y = type_convert<int8_t>(
alpha_ * type_convert<float>(x0) +
347 beta_ * type_convert<float>(x1));
359 template <
typename Y,
typename X0,
typename X1>
360 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
363 __host__ __device__ constexpr
void
364 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
366 const float a = x0 + x1;
371 __host__ __device__ constexpr
void
372 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
374 const double a = x0 + x1;
379 __host__ __device__ constexpr
void
389 __host__ __device__ constexpr
void
392 const float a = x0 + type_convert<float>(x1);
394 y = type_convert<half_t>(b);
398 __host__ __device__ constexpr
void
399 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
401 const float a = x0 + type_convert<float>(x1);
406 __host__ __device__ constexpr
void
409 const float a = x0 + type_convert<float>(x1);
411 y = type_convert<bhalf_t>(b);
415 __host__ __device__ constexpr
void
418 const float a = type_convert<float>(x0) + type_convert<float>(x1);
420 y = type_convert<bhalf_t>(b);
424 __host__ __device__ constexpr
void
425 operator()<int, int,
int8_t>(
int& y,
const int& x0,
const int8_t& x1)
const
432 __host__ __device__ constexpr
void
445 template <
typename Y,
typename X0,
typename X1>
446 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
449 __host__ __device__ constexpr
void
450 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
452 const float a = x0 + x1;
453 y =
a > 0.0f ?
a : 0.0f;
457 __host__ __device__ constexpr
void
458 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
460 const double a = x0 + x1;
461 y =
a > 0.0 ?
a : 0.0;
465 __host__ __device__ constexpr
void
469 y =
a > type_convert<half_t>(0.0f) ?
a : type_convert<half_t>(0.0f);
473 __host__ __device__ constexpr
void
476 const float a = x0 + type_convert<float>(x1);
477 const float b =
a > 0.0f ?
a : 0.0f;
478 y = type_convert<half_t>(b);
482 __host__ __device__ constexpr
void
483 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
485 const float a = x0 + type_convert<float>(x1);
486 y =
a > 0.0f ?
a : 0.0f;
490 __host__ __device__ constexpr
void
493 const float a = x0 + type_convert<float>(x1);
494 const float b =
a > 0.0f ?
a : 0.0f;
495 y = type_convert<bhalf_t>(b);
499 __host__ __device__ constexpr
void
502 const float a = type_convert<float>(x0) + type_convert<float>(x1);
503 const float b =
a > 0.0f ?
a : 0.0f;
504 y = type_convert<bhalf_t>(b);
508 __host__ __device__ constexpr
void
509 operator()<int, int,
int8_t>(
int& y,
const int& x0,
const int8_t& x1)
const
516 __host__ __device__ constexpr
void
526 template <
typename T>
527 __host__ __device__ constexpr
void operator()(T& y,
const T& x0,
const T& x1)
const;
530 __host__ __device__ constexpr
void
531 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
534 float b =
a +
float{3};
535 float c = (b > 0) * (b > 6.0f ? 6.0f : b) *
a * 0.166667f;
540 __host__ __device__ constexpr
void
541 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
545 double c = (b > 0) * (b > 6.0 ? 6.0 : b) *
a * 0.166667;
550 __host__ __device__ constexpr
void
555 float c = (b > 0) * (b > 6.0f ? 6.0f : b) *
a * 0.166667f;
563 template <
typename E,
typename C,
typename D>
564 __host__ __device__ constexpr
void operator()(E& e,
const C& c,
const D& d)
const;
567 __host__ __device__ constexpr
void
568 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
570 const float x = c + d;
572 FastGelu{}.template operator()<float,
float>(e, x);
576 __host__ __device__ constexpr
void
585 __host__ __device__ constexpr
void
588 const float x0_f = c + d;
595 e = type_convert<half_t>(x1_f);
599 __host__ __device__ constexpr
void
602 const float x0_f = type_convert<float>(c) + type_convert<float>(d);
606 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
608 e = type_convert<bhalf_t>(x1_f);
612 __host__ __device__ constexpr
void
615 const float x0_f = c + type_convert<float>(d);
619 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
621 e = type_convert<bhalf_t>(x1_f);
628 template <
typename E,
typename C,
typename D>
629 __host__ __device__ constexpr
void operator()(E& e,
const C& c,
const D& d)
const;
632 __host__ __device__ constexpr
void
633 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
635 const float x = c * d;
637 FastGelu{}.template operator()<float,
float>(e, x);
641 __host__ __device__ constexpr
void
650 __host__ __device__ constexpr
void
653 const float x0_f = c * d;
660 e = type_convert<half_t>(x1_f);
664 __host__ __device__ constexpr
void
667 const float x0_f = type_convert<float>(c) * type_convert<float>(d);
671 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
673 e = type_convert<bhalf_t>(x1_f);
677 __host__ __device__ constexpr
void
680 const float x0_f = c * type_convert<float>(d);
684 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
686 e = type_convert<bhalf_t>(x1_f);
693 template <
typename E,
typename C,
typename D>
694 __host__ __device__ constexpr
void operator()(E& e,
const C& c,
const D& d)
const;
697 __host__ __device__ constexpr
void
698 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
700 const float x = c + d;
702 Silu{}.template operator()<
float>(e, x);
706 __host__ __device__ constexpr
void
715 __host__ __device__ constexpr
void
718 const float x0_f = c + d;
722 Silu{}.template operator()<
float>(x1_f, x0_f);
724 e = type_convert<half_t>(x1_f);
728 __host__ __device__ constexpr
void
731 const float x0_f = c + type_convert<float>(d);
735 Silu{}.template operator()<
float>(x1_f, x0_f);
737 e = type_convert<bhalf_t>(x1_f);
744 float scale_wei = 1.f,
745 float scale_out = 1.f)
750 template <
typename E,
typename C,
typename D>
751 __host__ __device__
void operator()(E& e,
const C& c,
const D& d)
const;
754 __host__ __device__
void
755 operator()<
f8_t, float,
float>(
f8_t& e,
const float& c,
const float& d)
const
__host__ T ceil(T x)
Definition: math_v2.hpp:331
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
_Float16 half_t
Definition: data_type.hpp:30
ushort bhalf_t
Definition: data_type.hpp:29
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:309
Definition: binary_element_wise_operation.hpp:355
AddClamp(float floor=0.f, float ceil=NumericLimits< float >::Max())
Definition: binary_element_wise_operation.hpp:356
const float ceil_
Definition: binary_element_wise_operation.hpp:440
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
const float floor_
Definition: binary_element_wise_operation.hpp:437
Definition: binary_element_wise_operation.hpp:562
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
Definition: binary_element_wise_operation.hpp:525
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const
Definition: binary_element_wise_operation.hpp:14
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:444
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:692
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
Definition: binary_element_wise_operation.hpp:281
Bilinear(float alpha=1.f, float beta=1.f)
Definition: binary_element_wise_operation.hpp:282
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &) const
float beta_
Definition: binary_element_wise_operation.hpp:351
float alpha_
Definition: binary_element_wise_operation.hpp:348
Definition: binary_element_wise_operation.hpp:742
float scale_in_
Definition: binary_element_wise_operation.hpp:760
float scale_wei_
Definition: binary_element_wise_operation.hpp:763
__host__ __device__ ConvScaleAdd(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition: binary_element_wise_operation.hpp:743
float scale_out_
Definition: binary_element_wise_operation.hpp:764
__host__ __device__ void operator()(E &e, const C &c, const D &d) const
Definition: unary_element_wise_operation.hpp:892
Definition: binary_element_wise_operation.hpp:96
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:98
Definition: binary_element_wise_operation.hpp:107
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:109
Definition: binary_element_wise_operation.hpp:627
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
Definition: binary_element_wise_operation.hpp:118
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:210
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:214
float scale_
Definition: binary_element_wise_operation.hpp:231
__host__ __device__ ScaleAdd(float scale=1.f)
Definition: binary_element_wise_operation.hpp:211
Definition: unary_element_wise_operation.hpp:1049
Definition: binary_element_wise_operation.hpp:237
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const