/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp Source File
element_wise_operation.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math_v2.hpp"
11 
12 namespace ck {
13 namespace tensor_operation {
14 namespace element_wise {
15 
16 // Need to ensure the compiler fails if there is no matching candidate, instead of the compiler
17 // silently doing an implicit type conversion.
18 //
19 // Example:
20 //
21 // struct ExampleElementwiseOp
22 // {
23 // template<typename Y, typename X>
24 // __host__ __device__ constexpr void
25 // operator()(Y&, const X) const;
26 //
27 // template<>
28 // __host__ __device__ constexpr void
29 // operator()<half_t, half_t>(half_t& y, const half_t& x) const
30 // {
31 // }
32 // };
33 
34 struct AddReluAdd
35 {
36  static constexpr const char* name = "AddReluAdd";
37 
38  template <typename Y, typename X0, typename X1, typename X2>
39  __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
40 
41  template <>
42  __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
43  half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
44  {
45  half_t a = x0 + x1;
46  half_t b = a > 0 ? a : 0;
47  y = b + x2;
48  }
49 
50  template <>
51  __host__ __device__ constexpr void operator()<float, float, float, float>(float& y,
52  const float& x0,
53  const float& x1,
54  const float& x2) const
55  {
56  float a = x0 + x1;
57  float b = a > 0 ? a : 0;
58  float c = b + x2;
59  y = c;
60  }
61 
62  template <>
63  __host__ __device__ constexpr void operator()<float, float, half_t, half_t>(
64  float& y, const float& x0, const half_t& x1, const half_t& x2) const
65  {
66  float a = x0 + x1;
67  float b = a > 0 ? a : 0;
68  float c = b + x2;
69  y = c;
70  }
71 
72  template <>
73  __host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
74  half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
75  {
76  float y_float;
77  (*this)(y_float, x0, x1, x2);
78  y = y_float;
79  }
80 
81  template <>
82  __host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
83  bhalf_t& y, const float& x0, const bhalf_t& x1, const bhalf_t& x2) const
84  {
85  float a = x0 + x1;
86  float b = a > 0 ? a : 0;
87  float c = b + x2;
88  y = c;
89  }
90 
91  template <>
92  __host__ __device__ constexpr void operator()<int8_t, int8_t, int8_t, int8_t>(
93  int8_t& y, const int8_t& x0, const int8_t& x1, const int8_t& x2) const
94  {
95  int32_t a = x0 + x1;
96  int32_t b = a > 0 ? a : 0;
97  int32_t c = b + x2;
98  y = c;
99  }
100 
101 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
102  template <>
103  __host__ __device__ constexpr void operator()<int4_t, int8_t, int4_t, int4_t>(
104  int4_t& y, const int8_t& x0, const int4_t& x1, const int4_t& x2) const
105  {
106  int32_t a = x0 + x1;
107  int32_t b = a > 0 ? a : 0;
108  int32_t c = b + x2;
109  y = c;
110  }
111 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
112 };
113 
115 {
116  static constexpr const char* name = "AddHardswishAdd";
117 
118  template <typename Y, typename X0, typename X1, typename X2>
119  __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
120 
121  template <>
122  __host__ __device__ constexpr void operator()<float, float, float, float>(float& y,
123  const float& x0,
124  const float& x1,
125  const float& x2) const
126  {
127  float a = x0 + x1;
128  float b = a + float{3};
129  float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
130  float d = c + x2;
131  y = d;
132  }
133 
134  template <>
135  __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
136  half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
137  {
138  float a = x0 + x1;
139  float b = a + float{3};
140  float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
141  float d = c + x2;
142  y = d;
143  }
144 };
145 
146 // C = A * B
147 // E = C + D0 + D1
148 struct AddAdd
149 {
150  static constexpr const char* name = "AddAdd";
151 
152  template <typename E, typename C, typename D0, typename D1>
153  __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
154  {
155  // Only support floating so far
158  "Data type is not supported by this operation!");
159 
162  "Data type is not supported by this operation!");
163 
166  "Data type is not supported by this operation!");
167 
170  "Data type is not supported by this operation!");
171 
172  const C y = c + type_convert<C>(d0) + type_convert<C>(d1);
173  e = type_convert<E>(y);
174  }
175 };
176 
177 // C = A * B
178 // E = (C + D0) x D1
180 {
181  static constexpr const char* name = "AddMultiply";
182 
183  template <typename E, typename C, typename D0, typename D1>
184  __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
185 
186  template <>
187  __host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
188  const half_t& c,
189  const half_t& d0,
190  const half_t& d1) const
191  {
192  const half_t y = (c + d0) * d1;
193  e = y;
194  }
195  template <>
196  __host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
197  const float& c,
198  const half_t& d0,
199  const half_t& d1) const
200  {
201  const half_t y = (type_convert<half_t>(c) + d0) * d1;
202  e = y;
203  }
204  template <>
205  __host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
206  const float& c,
207  const half_t& d0,
208  const half_t& d1) const
209  {
210  const float y = (c + d0) * d1;
211  e = y;
212  }
213 };
214 
215 // C = A * B
216 // E = C x D0 + D1
218 {
219  static constexpr const char* name = "MultiplyAdd";
220 
221  template <typename E, typename C, typename D0, typename D1>
222  __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
223 
224  template <>
225  __host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
226  const half_t& c,
227  const half_t& d0,
228  const half_t& d1) const
229  {
230  const half_t y = (c * d0) + d1;
231  e = y;
232  }
233  template <>
234  __host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
235  const float& c,
236  const half_t& d0,
237  const half_t& d1) const
238  {
239  const half_t y = type_convert<half_t>(c) * d0 + d1;
240  e = y;
241  }
242  template <>
243  __host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
244  const float& c,
245  const bhalf_t& d0,
246  const bhalf_t& d1) const
247  {
248  const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
249  e = y;
250  }
251  template <>
252  __host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
253  const float& c,
254  const half_t& d0,
255  const half_t& d1) const
256  {
257  const float y = c * d0 + d1;
258  e = y;
259  }
260  template <>
261  __host__ __device__ void operator()<half_t, float, float, float>(half_t& e,
262  const float& c,
263  const float& d0,
264  const float& d1) const
265  {
266  const float y = c * d0 + d1;
267  e = y;
268  }
269 };
270 
272 {
273  static constexpr const char* name = "MultiplyMultiply";
274 
275  template <typename E, typename C, typename D0, typename D1>
276  __host__ __device__ constexpr void
277  operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
278 
279  template <>
280  __host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
281  ck::half_t& e, const float& c, const float& d0, const float& d1) const
282  {
283  const float x0_f = c * d0 * d1;
284 
285  e = ck::type_convert<ck::half_t>(x0_f);
286  }
287 
288  template <>
289  __host__ __device__ constexpr void operator()<ck::bhalf_t, float, float, float>(
290  ck::bhalf_t& e, const float& c, const float& d0, const float& d1) const
291  {
292  const float x0_f = c * d0 * d1;
293 
294  e = ck::type_convert<ck::bhalf_t>(x0_f);
295  }
296 
297  template <>
298  __host__ __device__ constexpr void operator()<ck::half_t, int, ck::half_t, ck::half_t>(
299  ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const
300  {
301  const float x0_f =
302  ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
303 
304  e = ck::type_convert<ck::half_t>(x0_f);
305  }
306 
307  template <>
308  __host__ __device__ constexpr void operator()<ck::half_t, int, float, float>(
309  ck::half_t& e, const int& c, const float& d0, const float& d1) const
310  {
311  const float x0_f =
312  ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
313 
314  e = ck::type_convert<ck::half_t>(x0_f);
315  }
316 
317  template <>
318  __host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
319  ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
320  {
321  const float x0_f =
322  ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
323 
324  e = ck::type_convert<ck::bhalf_t>(x0_f);
325  }
326 };
327 
329 {
330  static constexpr const char* name = "MultiplyAddFastGelu";
331 
332  template <typename E, typename C, typename D0, typename D1>
333  __host__ __device__ constexpr void
334  operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
335 
336  template <>
337  __host__ __device__ constexpr void operator()<ck::bhalf_t, float, ck::bhalf_t, ck::bhalf_t>(
338  ck::bhalf_t& e, const float& c, const ck::bhalf_t& d0, const ck::bhalf_t& d1) const
339  {
340  const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
341 
342  float x1_f = 0;
343 
344  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
345 
346  e = ck::type_convert<ck::bhalf_t>(x1_f);
347  }
348 };
349 
350 // E = FastGelu(C + D0 + D1)
352 {
353  static constexpr const char* name = "AddAddFastGelu";
354 
355  template <typename E, typename C, typename D0, typename D1>
356  __host__ __device__ constexpr void
357  operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
358 
359  template <>
360  __host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
361  const float& c,
362  const float& d0,
363  const float& d1) const
364  {
365  const float x = c + d0 + d1;
366 
367  FastGelu{}.template operator()<float, float>(e, x);
368  }
369 
370  template <>
371  __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
372  half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
373  {
374  const half_t x = c + d0 + d1;
375 
376  ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
377  }
378 
379  template <>
380  __host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
381  half_t& e, const float& c, const half_t& d0, const half_t& d1) const
382  {
383  const float x0_f = c + d0 + d1;
384 
385  float x1_f = 0;
386 
387  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
388  x0_f);
389 
390  e = type_convert<half_t>(x1_f);
391  }
392 
393  template <>
394  __host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
395  bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const
396  {
397  const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
398 
399  float x1_f = 0;
400 
401  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
402  x0_f);
403 
404  e = type_convert<bhalf_t>(x1_f);
405  }
406 
407  template <>
408  __host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t, int8_t>(
409  int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const
410  {
411  const float x0_f =
412  type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
413 
414  float x1_f = 0;
415 
416  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
417  x0_f);
418 
419  e = type_convert<int8_t>(x1_f);
420  }
421 };
422 
423 // E = Relu(alpha1 * C + alpha2 * D0 + D1)
425 {
426  static constexpr const char* name = "ScaleAddScaleAddRelu";
427 
428  ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
429  : alpha1_(alpha1), alpha2_(alpha2)
430  {
431  }
432 
433  template <typename E, typename C, typename D0, typename D1>
434  __host__ __device__ constexpr void
435  operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
436 
437  template <>
438  __host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
439  const float& c,
440  const float& d0,
441  const float& d1) const
442  {
443  const float x = c * alpha1_ + alpha2_ * d0 + d1;
444  e = x > 0 ? x : 0;
445  }
446 
447  template <>
448  __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
449  half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
450  {
451  const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
452  type_convert<float>(d1);
453 
454  float result = 0;
455  result = x > 0 ? x : 0;
456 
457  e = type_convert<half_t>(result);
458  }
459 
460  template <>
461  __host__ __device__ constexpr void operator()<bhalf_t, bhalf_t, bhalf_t, bhalf_t>(
462  bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const
463  {
464  const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
465  type_convert<float>(d1);
466 
467  float result = 0;
468  result = x > 0 ? x : 0;
469 
470  e = type_convert<bhalf_t>(result);
471  }
472 
473  template <>
474  __host__ __device__ constexpr void operator()<int8_t, int8_t, float, float>(
475  int8_t& e, const int8_t& c, const float& d0, const float& d1) const
476  {
477  const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;
478 
479  float result = 0;
480  result = x > 0 ? x : 0;
481 
482  e = type_convert<int8_t>(result);
483  }
484 
485  const float alpha1_;
486  const float alpha2_;
487 };
488 
489 struct Normalize
490 {
491  static constexpr const char* name = "Normalize";
492 
493  // FIXME: is double absolutely necessary?
494  Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
495 
496  template <typename T1, typename T2, typename T3>
497  __host__ __device__ constexpr void operator()(T1& y,
498  const T1& x,
499  const T2& mean,
500  const T2& mean_square,
501  const T3& gamma,
502  const T3& beta) const;
503 
504  template <>
505  __host__ __device__ constexpr void operator()<half_t, float, half_t>(half_t& y,
506  const half_t& x,
507  const float& mean,
508  const float& mean_square,
509  const half_t& gamma,
510  const half_t& beta) const
511  {
512  using ck::math::sqrt;
513 
514  float variance = mean_square - (mean * mean);
515 
516  float tmp_x = type_convert<float>(x);
517  float tmp_gamma = type_convert<float>(gamma);
518  float tmp_beta = type_convert<float>(beta);
519 
520  float tmp_y =
521  ((tmp_x - mean) / sqrt(variance + type_convert<float>(epsilon_))) * tmp_gamma +
522  tmp_beta;
523 
524  y = type_convert<half_t>(tmp_y);
525  };
526 
527  template <>
528  __host__ __device__ constexpr void operator()<float, float, float>(float& y,
529  const float& x,
530  const float& mean,
531  const float& mean_square,
532  const float& gamma,
533  const float& beta) const
534  {
535  using ck::math::sqrt;
536 
537  float variance = mean_square - (mean * mean);
538  y = ((x - mean) / sqrt(variance + type_convert<float>(epsilon_))) * gamma + beta;
539  };
540 
541  template <>
542  __host__ __device__ constexpr void operator()<double, double, double>(double& y,
543  const double& x,
544  const double& mean,
545  const double& mean_square,
546  const double& gamma,
547  const double& beta) const
548  {
549  using ck::math::sqrt;
550 
551  double variance = mean_square - (mean * mean);
552  y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta;
553  };
554 
555  // FIXME: is double absolutely necessary?
556  double epsilon_;
557 };
558 
559 // used by BatchNorm inference
560 // y = gamma * (x-mean) / sqrt(epsilon+variance) + beta
561 // The data type of mean and variance is used as AccDataType
563 {
564  static constexpr const char* name = "NormalizeInInfer";
565 
566  NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
567 
568  template <typename T1, typename T2, typename T3, typename T4>
569  __host__ __device__ constexpr void operator()(T1& y,
570  const T1& x,
571  const T2& mean,
572  const T2& variance,
573  const T3& gamma,
574  const T4& beta) const
575  {
577  "Data type is not supported by this operation!");
578 
579  using ck::type_convert;
580  using ck::math::sqrt;
581 
582  T2 tmp_x, tmp_y;
583 
584  tmp_x = type_convert<T2>(x);
585 
586  tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
587  type_convert<T2>(gamma) +
588  type_convert<T2>(beta);
589  y = type_convert<T1>(tmp_y);
590  };
591 
592  double epsilon_;
593 };
594 
595 // used by Conv+Bias+BatchNorm+Clamp inference
597 {
598  static constexpr const char* name = "BiasNormalizeInInferClamp";
599 
602  float epsilon = 1e-4)
603  : clamp_(floor, ceil), epsilon_(epsilon)
604  {
605  }
606 
607  template <typename T>
608  __host__ __device__ constexpr void operator()(T& y,
609  const T& x,
610  const T& bias,
611  const T& mean,
612  const T& variance,
613  const T& gamma,
614  const T& beta) const
615  {
616  using ck::type_convert;
617  using ck::math::sqrt;
618 
619  float tmp_x = type_convert<float>(x) + type_convert<float>(bias);
620 
621  float tmp_y =
622  ((tmp_x - type_convert<float>(mean)) / sqrt(type_convert<float>(variance) + epsilon_)) *
623  type_convert<float>(gamma) +
624  type_convert<float>(beta);
625  clamp_(tmp_y, tmp_y);
626  y = type_convert<T>(tmp_y);
627  };
628 
629  template <>
630  __host__ __device__ constexpr void operator()(float& y,
631  const float& x,
632  const float& bias,
633  const float& mean,
634  const float& variance,
635  const float& gamma,
636  const float& beta) const
637  {
638  using ck::type_convert;
639  using ck::math::sqrt;
640 
641  float tmp_y = (((x + bias) - mean) / sqrt(variance + epsilon_)) * gamma + beta;
642  clamp_(y, tmp_y);
643  };
644 
646  float epsilon_;
647 };
648 
649 template <typename Y, typename X>
651 
652 template <>
653 struct UnaryTypeConvert<float, ck::bhalf_t>
654 {
655  static constexpr const char* name = "UnaryTypeConvert";
656 
657  __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
658  {
659  y = ck::type_convert<float, ck::bhalf_t>(x);
660  }
661 };
662 
663 template <>
664 struct UnaryTypeConvert<ck::bhalf_t, float>
665 {
666  static constexpr const char* name = "UnaryTypeConvert";
667 
668  __host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
669  {
670  y = ck::type_convert<ck::bhalf_t, float>(x);
671  }
672 };
673 
674 } // namespace element_wise
675 } // namespace tensor_operation
676 } // namespace ck
__host__ T ceil(T x)
Definition: math_v2.hpp:331
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: ck.hpp:268
_Float16 half_t
Definition: data_type.hpp:31
ushort bhalf_t
Definition: data_type.hpp:30
__host__ constexpr __device__ Y type_convert(X x)
Definition: type_convert.hpp:98
_BitInt(4) int4_t
Definition: data_type.hpp:32
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:309
Definition: type.hpp:177
Definition: element_wise_operation.hpp:352
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:353
Definition: element_wise_operation.hpp:149
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:153
static constexpr const char * name
Definition: element_wise_operation.hpp:150
Definition: element_wise_operation.hpp:115
static constexpr const char * name
Definition: element_wise_operation.hpp:116
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: element_wise_operation.hpp:180
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:181
Definition: element_wise_operation.hpp:35
static constexpr const char * name
Definition: element_wise_operation.hpp:36
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: element_wise_operation.hpp:597
BiasNormalizeInInferClamp(float floor=0.f, float ceil=NumericLimits< float >::Max(), float epsilon=1e-4)
Definition: element_wise_operation.hpp:600
__host__ constexpr __device__ void operator()(T &y, const T &x, const T &bias, const T &mean, const T &variance, const T &gamma, const T &beta) const
Definition: element_wise_operation.hpp:608
float epsilon_
Definition: element_wise_operation.hpp:646
Clamp clamp_
Definition: element_wise_operation.hpp:643
__host__ constexpr __device__ void operator()(float &y, const float &x, const float &bias, const float &mean, const float &variance, const float &gamma, const float &beta) const
Definition: element_wise_operation.hpp:630
static constexpr const char * name
Definition: element_wise_operation.hpp:598
Definition: unary_element_wise_operation.hpp:811
Definition: unary_element_wise_operation.hpp:924
Definition: element_wise_operation.hpp:329
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:330
Definition: element_wise_operation.hpp:218
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:219
Definition: element_wise_operation.hpp:272
static constexpr const char * name
Definition: element_wise_operation.hpp:273
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:490
Normalize(double epsilon=1e-4)
Definition: element_wise_operation.hpp:494
double epsilon_
Definition: element_wise_operation.hpp:553
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &mean_square, const T3 &gamma, const T3 &beta) const
static constexpr const char * name
Definition: element_wise_operation.hpp:491
Definition: element_wise_operation.hpp:563
static constexpr const char * name
Definition: element_wise_operation.hpp:564
double epsilon_
Definition: element_wise_operation.hpp:590
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &variance, const T3 &gamma, const T4 &beta) const
Definition: element_wise_operation.hpp:569
NormalizeInInfer(double epsilon=1e-4)
Definition: element_wise_operation.hpp:566
Definition: element_wise_operation.hpp:425
ScaleAddScaleAddRelu(const float alpha1=1.f, const float alpha2=1.f)
Definition: element_wise_operation.hpp:428
static constexpr const char * name
Definition: element_wise_operation.hpp:426
const float alpha2_
Definition: element_wise_operation.hpp:486
const float alpha1_
Definition: element_wise_operation.hpp:485
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
__host__ __device__ void operator()(ck::bhalf_t &y, float &x) const
Definition: element_wise_operation.hpp:668
__host__ __device__ void operator()(float &y, ck::bhalf_t &x) const
Definition: element_wise_operation.hpp:657
Definition: element_wise_operation.hpp:650