/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp Source File
binary_element_wise_operation.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 namespace tensor_operation {
11 namespace element_wise {
12 
13 struct Add
14 {
15  static constexpr const char* name = "Add";
16 
17  template <typename Y, typename X0, typename X1>
18  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
19 
20  template <>
21  __host__ __device__ constexpr void
22  operator()<float>(float& y, const float& x0, const float& x1) const
23  {
24  y = x0 + x1;
25  };
26 
27  template <>
28  __host__ __device__ constexpr void
29  operator()<double>(double& y, const double& x0, const double& x1) const
30  {
31  y = x0 + x1;
32  };
33 
34  template <>
35  __host__ __device__ constexpr void
36  operator()<float>(float& y, const float& x0, const half_t& x1) const
37  {
38  y = x0 + type_convert<half_t>(x1);
39  };
40 
41  template <>
42  __host__ __device__ constexpr void
43  operator()<half_t>(half_t& y, const float& x0, const float& x1) const
44  {
45  y = type_convert<half_t>(x0 + x1);
46  };
47 
48  template <>
49  __host__ __device__ constexpr void
50  operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
51  {
52  y = x0 + type_convert<float>(x1);
53  };
54 
55  template <>
56  __host__ __device__ constexpr void
57  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
58  {
59  y = x0 + x1;
60  };
61 
62  template <>
63  __host__ __device__ constexpr void
64  operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
65  {
66  const float x1_tmp = ck::type_convert<float>(x1);
67  y = x0 + x1_tmp;
68  }
69 
70  template <>
71  __host__ __device__ constexpr void
72  operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
73  {
74  const float x1_tmp = ck::type_convert<float>(x0);
75  const float x2_tmp = ck::type_convert<float>(x1);
76  const float y_tmp = x1_tmp + x2_tmp;
77  y = ck::type_convert<bhalf_t>(y_tmp);
78  }
79 
80  template <>
81  __host__ __device__ constexpr void
82  operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
83  {
84  const float x2_tmp = ck::type_convert<float>(x1);
85  const float y_tmp = x0 + x2_tmp;
86  y = ck::type_convert<bhalf_t>(y_tmp);
87  }
88 
89  template <>
90  __host__ __device__ constexpr void
91  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
92  {
93  y = x0 + x1;
94  };
95 };
96 
97 struct Max
98 {
99  static constexpr const char* name = "Max";
100 
101  template <typename Y, typename X0, typename X1>
102  __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
103  {
104  const Y x0_converted = type_convert<Y>(x0);
105  const Y x1_converted = type_convert<Y>(x1);
106  y = ck::math::max(x0_converted, x1_converted);
107  }
108 };
109 
110 struct Min
111 {
112  static constexpr const char* name = "Min";
113 
114  template <typename Y, typename X0, typename X1>
115  __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
116  {
117  const Y x0_converted = type_convert<Y>(x0);
118  const Y x1_converted = type_convert<Y>(x1);
119  y = ck::math::min(x0_converted, x1_converted);
120  }
121 };
122 
123 struct Multiply
124 {
125  static constexpr const char* name = "Multiply";
126 
127  template <typename Y, typename X0, typename X1>
128  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
129 
130  template <>
131  __host__ __device__ constexpr void
132  operator()<float>(float& y, const float& x0, const float& x1) const
133  {
134  y = x0 * x1;
135  };
136 
137  template <>
138  __host__ __device__ constexpr void
139  operator()<double>(double& y, const double& x0, const double& x1) const
140  {
141  y = x0 * x1;
142  };
143 
144  template <>
145  __host__ __device__ constexpr void
146  operator()<float>(float& y, const float& x0, const half_t& x1) const
147  {
148  y = x0 * type_convert<half_t>(x1);
149  };
150 
151  template <>
152  __host__ __device__ constexpr void
153  operator()<half_t>(half_t& y, const float& x0, const float& x1) const
154  {
155  y = type_convert<half_t>(x0 * x1);
156  };
157 
158  template <>
159  __host__ __device__ constexpr void
160  operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
161  {
162  y = type_convert<half_t>(x0) * x1;
163  };
164 
165  template <>
166  __host__ __device__ constexpr void
167  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
168  {
169  y = x0 * x1;
170  };
171 
172  template <>
173  __host__ __device__ constexpr void
174  operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
175  {
176  const float x1_tmp = ck::type_convert<float>(x1);
177  y = x0 * x1_tmp;
178  }
179 
180  template <>
181  __host__ __device__ constexpr void
182  operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
183  {
184  const float x1_tmp = ck::type_convert<float>(x0);
185  const float x2_tmp = ck::type_convert<float>(x1);
186  const float y_tmp = x1_tmp * x2_tmp;
187  y = ck::type_convert<bhalf_t>(y_tmp);
188  }
189 
190  template <>
191  __host__ __device__ constexpr void
192  operator()<bhalf_t>(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const
193  {
194  const float x1_tmp = ck::type_convert<float>(x0);
195  const float x2_tmp = ck::type_convert<float>(x1);
196  const float y_tmp = x1_tmp * x2_tmp;
197  y = ck::type_convert<bhalf_t>(y_tmp);
198  }
199 
200  template <>
201  __host__ __device__ constexpr void
202  operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
203  {
204  const float x2_tmp = ck::type_convert<float>(x1);
205  const float y_tmp = x0 * x2_tmp;
206  y = ck::type_convert<bhalf_t>(y_tmp);
207  }
208 
209  template <>
210  __host__ __device__ constexpr void
211  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
212  {
213  y = x0 * x1;
214  };
215 };
216 
217 struct ScaleAdd
218 {
219  static constexpr const char* name = "ScaleAdd";
220 
221  __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
222 
223  template <typename Y, typename X0, typename X1>
224  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
225  {
226  y = ck::type_convert<Y>(scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
227  }
228 
229  template <>
230  __host__ __device__ void
231  operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
232  {
233  y = scale_ * x0 + ck::type_convert<float>(x1);
234  };
235 
236  template <>
237  __host__ __device__ void
238  operator()<float, float, bhalf_t>(float& y, const float& x0, const bhalf_t& x1) const
239  {
240  y = scale_ * x0 + ck::type_convert<float>(x1);
241  };
242 
243  float scale_;
244 };
245 
246 struct Subtract
247 {
248  static constexpr const char* name = "Subtract";
249 
250  template <typename T>
251  __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
252 
253  template <>
254  __host__ __device__ constexpr void
255  operator()<float>(float& y, const float& x0, const float& x1) const
256  {
257  y = x0 - x1;
258  };
259 
260  template <>
261  __host__ __device__ constexpr void
262  operator()<double>(double& y, const double& x0, const double& x1) const
263  {
264  y = x0 - x1;
265  };
266 
267  template <>
268  __host__ __device__ constexpr void
269  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
270  {
271  y = x0 - x1;
272  };
273 
274  template <>
275  __host__ __device__ constexpr void
276  operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
277  {
278  const float x1_tmp = ck::type_convert<float>(x0);
279  const float x2_tmp = ck::type_convert<float>(x1);
280  const float y_tmp = x1_tmp - x2_tmp;
281  y = ck::type_convert<bhalf_t>(y_tmp);
282  }
283 
284  template <>
285  __host__ __device__ constexpr void
286  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
287  {
288  y = x0 - x1;
289  };
290 };
291 
292 struct Bilinear
293 {
294  static constexpr const char* name = "Bilinear";
295 
296  Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
297 
298  template <typename Y, typename X0, typename X1>
299  __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
300 
301  template <>
302  __host__ __device__ constexpr void
303  operator()<double, double, double>(double& y, const double& x0, const double& x1) const
304  {
305  y = alpha_ * x0 + beta_ * x1;
306  };
307 
308  template <>
309  __host__ __device__ constexpr void
310  operator()<float, float, float>(float& y, const float& x0, const float& x1) const
311  {
312  y = alpha_ * x0 + beta_ * x1;
313  };
314 
315  template <>
316  __host__ __device__ constexpr void
317  operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
318  {
319  y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
320  beta_ * type_convert<float>(x1));
321  };
322 
323  template <>
324  __host__ __device__ constexpr void
325  operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
326  {
327  y = type_convert<half_t>(alpha_) * x0 + type_convert<half_t>(beta_) * x1;
328  };
329 
330  template <>
331  __host__ __device__ constexpr void
332  operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
333  {
334  y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
335  };
336 
337  template <>
338  __host__ __device__ constexpr void
339  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
340  {
341  const float x0_tmp = type_convert<float>(x0);
342  const float x1_tmp = type_convert<float>(x1);
343  const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp;
344  y = type_convert<bhalf_t>(y_tmp);
345  };
346 
347  template <>
348  __host__ __device__ constexpr void
349  operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
350  {
351  const float x1_tmp = ck::type_convert<float>(x1);
352  const float y_tmp = alpha_ * x0 + beta_ * x1_tmp;
353  y = y_tmp;
354  };
355 
356  template <>
357  __host__ __device__ constexpr void
358  operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
359  {
360  y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
361  beta_ * type_convert<float>(x1));
362  };
363 
364  float alpha_;
365  float beta_;
366 };
367 
368 struct AddClamp
369 {
370  static constexpr const char* name = "AddClamp";
371 
373  : floor_(floor), ceil_(ceil){};
374 
375  template <typename Y, typename X0, typename X1>
376  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
377 
378  template <>
379  __host__ __device__ constexpr void
380  operator()<float, float, float>(float& y, const float& x0, const float& x1) const
381  {
382  const float a = x0 + x1;
383  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
384  };
385 
386  template <>
387  __host__ __device__ constexpr void
388  operator()<double, double, double>(double& y, const double& x0, const double& x1) const
389  {
390  const double a = x0 + x1;
391  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
392  };
393 
394  template <>
395  __host__ __device__ constexpr void
396  operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
397  {
398  const half_t floor = type_convert<half_t>(floor_);
399  const half_t ceil = type_convert<half_t>(ceil_);
400  const half_t a = x0 + x1;
401  y = a > floor ? (a < ceil ? a : ceil) : floor;
402  };
403 
404  template <>
405  __host__ __device__ constexpr void
406  operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
407  {
408  const float a = x0 + type_convert<float>(x1);
409  const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
410  y = type_convert<half_t>(b);
411  };
412 
413  template <>
414  __host__ __device__ constexpr void
415  operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
416  {
417  const float a = x0 + type_convert<float>(x1);
418  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
419  };
420 
421  template <>
422  __host__ __device__ constexpr void
423  operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
424  {
425  const float a = x0 + type_convert<float>(x1);
426  const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
427  y = type_convert<bhalf_t>(b);
428  };
429 
430  template <>
431  __host__ __device__ constexpr void
432  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
433  {
434  const float a = type_convert<float>(x0) + type_convert<float>(x1);
435  const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
436  y = type_convert<bhalf_t>(b);
437  };
438 
439  template <>
440  __host__ __device__ constexpr void
441  operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
442  {
443  const int8_t a = x0 + x1;
444  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
445  };
446 
447  template <>
448  __host__ __device__ constexpr void
449  operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
450  {
451  const int8_t a = x0 + x1;
452  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
453  };
454 
455  const float floor_;
456  const float ceil_;
457 };
458 
459 struct AddRelu
460 {
461  static constexpr const char* name = "AddRelu";
462 
463  template <typename Y, typename X0, typename X1>
464  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
465 
466  template <>
467  __host__ __device__ constexpr void
468  operator()<float, float, float>(float& y, const float& x0, const float& x1) const
469  {
470  const float a = x0 + x1;
471  y = a > 0.0f ? a : 0.0f;
472  };
473 
474  template <>
475  __host__ __device__ constexpr void
476  operator()<double, double, double>(double& y, const double& x0, const double& x1) const
477  {
478  const double a = x0 + x1;
479  y = a > 0.0 ? a : 0.0;
480  };
481 
482  template <>
483  __host__ __device__ constexpr void
484  operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
485  {
486  const half_t a = x0 + x1;
487  y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
488  };
489 
490  template <>
491  __host__ __device__ constexpr void
492  operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
493  {
494  const float a = x0 + type_convert<float>(x1);
495  const float b = a > 0.0f ? a : 0.0f;
496  y = type_convert<half_t>(b);
497  };
498 
499  template <>
500  __host__ __device__ constexpr void
501  operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
502  {
503  const float a = x0 + type_convert<float>(x1);
504  y = a > 0.0f ? a : 0.0f;
505  };
506 
507  template <>
508  __host__ __device__ constexpr void
509  operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
510  {
511  const float a = x0 + type_convert<float>(x1);
512  const float b = a > 0.0f ? a : 0.0f;
513  y = type_convert<bhalf_t>(b);
514  };
515 
516  template <>
517  __host__ __device__ constexpr void
518  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
519  {
520  const float a = type_convert<float>(x0) + type_convert<float>(x1);
521  const float b = a > 0.0f ? a : 0.0f;
522  y = type_convert<bhalf_t>(b);
523  };
524 
525  template <>
526  __host__ __device__ constexpr void
527  operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
528  {
529  const int8_t a = x0 + x1;
530  y = a > 0 ? a : 0;
531  };
532 
533  template <>
534  __host__ __device__ constexpr void
535  operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
536  {
537  const int8_t a = x0 + x1;
538  y = a > 0 ? a : 0;
539  };
540 };
541 
543 {
544  static constexpr const char* name = "AddHardswish";
545 
546  template <typename T>
547  __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
548 
549  template <>
550  __host__ __device__ constexpr void
551  operator()<float>(float& y, const float& x0, const float& x1) const
552  {
553  float a = x0 + x1;
554  float b = a + float{3};
555  float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
556  y = c;
557  };
558 
559  template <>
560  __host__ __device__ constexpr void
561  operator()<double>(double& y, const double& x0, const double& x1) const
562  {
563  double a = x0 + x1;
564  double b = a + 3.0;
565  double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667;
566  y = c;
567  };
568 
569  template <>
570  __host__ __device__ constexpr void
571  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
572  {
573  float a = x0 + x1;
574  float b = a + 3.0f;
575  float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
576  y = c;
577  };
578 };
579 
580 // E = FastGelu(C + D)
582 {
583  static constexpr const char* name = "AddFastGelu";
584 
585  template <typename E, typename C, typename D>
586  __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
587 
588  template <>
589  __host__ __device__ constexpr void
590  operator()<float, float, float>(float& e, const float& c, const float& d) const
591  {
592  const float x = c + d;
593 
594  FastGelu{}.template operator()<float, float>(e, x);
595  }
596 
597  template <>
598  __host__ __device__ constexpr void
599  operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
600  {
601  const half_t x = c + d;
602 
603  ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
604  }
605 
606  template <>
607  __host__ __device__ constexpr void
608  operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
609  {
610  const float x0_f = c + d;
611 
612  float x1_f = 0;
613 
614  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
615  x0_f);
616 
617  e = type_convert<half_t>(x1_f);
618  }
619 
620  template <>
621  __host__ __device__ constexpr void
622  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
623  {
624  const float x0_f = type_convert<float>(c) + type_convert<float>(d);
625 
626  float x1_f = 0;
627 
628  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
629 
630  e = type_convert<bhalf_t>(x1_f);
631  }
632 
633  template <>
634  __host__ __device__ constexpr void
635  operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
636  {
637  const float x0_f = c + type_convert<float>(d);
638 
639  float x1_f = 0;
640 
641  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
642 
643  e = type_convert<bhalf_t>(x1_f);
644  }
645 };
646 
647 // E = MultiplyFastGelu(C + D)
649 {
650  static constexpr const char* name = "MultiplyFastGelu";
651 
652  template <typename E, typename C, typename D>
653  __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
654 
655  template <>
656  __host__ __device__ constexpr void
657  operator()<float, float, float>(float& e, const float& c, const float& d) const
658  {
659  const float x = c * d;
660 
661  FastGelu{}.template operator()<float, float>(e, x);
662  }
663 
664  template <>
665  __host__ __device__ constexpr void
666  operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
667  {
668  const half_t x = c * d;
669 
670  ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
671  }
672 
673  template <>
674  __host__ __device__ constexpr void
675  operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
676  {
677  const float x0_f = c * d;
678 
679  float x1_f = 0;
680 
681  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
682  x0_f);
683 
684  e = type_convert<half_t>(x1_f);
685  }
686 
687  template <>
688  __host__ __device__ constexpr void
689  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
690  {
691  const float x0_f = type_convert<float>(c) * type_convert<float>(d);
692 
693  float x1_f = 0;
694 
695  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
696 
697  e = type_convert<bhalf_t>(x1_f);
698  }
699 
700  template <>
701  __host__ __device__ constexpr void
702  operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
703  {
704  const float x0_f = c * type_convert<float>(d);
705 
706  float x1_f = 0;
707 
708  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
709 
710  e = type_convert<bhalf_t>(x1_f);
711  }
712 };
713 
714 // E = Silu(C + D)
715 struct AddSilu
716 {
717  static constexpr const char* name = "AddSilu";
718 
719  template <typename E, typename C, typename D>
720  __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
721 
722  template <>
723  __host__ __device__ constexpr void
724  operator()<float, float, float>(float& e, const float& c, const float& d) const
725  {
726  const float x = c + d;
727 
728  Silu{}.template operator()<float>(e, x);
729  }
730 
731  template <>
732  __host__ __device__ constexpr void
733  operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
734  {
735  const half_t x = c + d;
736 
737  Silu{}.template operator()<half_t>(e, x);
738  }
739 
740  template <>
741  __host__ __device__ constexpr void
742  operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
743  {
744  const float x0_f = c + d;
745 
746  float x1_f = 0;
747 
748  Silu{}.template operator()<float>(x1_f, x0_f);
749 
750  e = type_convert<half_t>(x1_f);
751  }
752 
753  template <>
754  __host__ __device__ constexpr void
755  operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
756  {
757  const float x0_f = c + type_convert<float>(d);
758 
759  float x1_f = 0;
760 
761  Silu{}.template operator()<float>(x1_f, x0_f);
762 
763  e = type_convert<bhalf_t>(x1_f);
764  }
765 };
766 
768 {
769  static constexpr const char* name = "ConvScaleAdd";
770 
771  __host__ __device__ ConvScaleAdd(float scale_in = 1.f,
772  float scale_wei = 1.f,
773  float scale_out = 1.f)
774  : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
775  {
776  }
777 
778  template <typename E, typename C, typename D>
779  __host__ __device__ void operator()(E& e, const C& c, const D& d) const;
780 
781  template <>
782  __host__ __device__ void
783  operator()<f8_t, float, float>(f8_t& e, const float& c, const float& d) const
784  {
785  float x;
786  Add{}.template operator()<float>(x, c * scale_in_ * scale_wei_, d);
787  e = type_convert<f8_t>(x * scale_out_);
788  };
789 
790  float scale_in_;
791  float scale_wei_;
792  float scale_out_;
793 };
794 
795 } // namespace element_wise
796 } // namespace tensor_operation
797 } // namespace ck
__host__ T ceil(T x)
Definition: math_v2.hpp:331
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: ck.hpp:268
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1762
_Float16 half_t
Definition: data_type.hpp:31
ushort bhalf_t
Definition: data_type.hpp:30
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:309
Definition: amd_ck_fp8.hpp:36
Definition: binary_element_wise_operation.hpp:369
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:370
AddClamp(float floor=0.f, float ceil=NumericLimits< float >::Max())
Definition: binary_element_wise_operation.hpp:372
const float ceil_
Definition: binary_element_wise_operation.hpp:456
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
const float floor_
Definition: binary_element_wise_operation.hpp:453
Definition: binary_element_wise_operation.hpp:582
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:583
Definition: binary_element_wise_operation.hpp:543
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:544
Definition: binary_element_wise_operation.hpp:14
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:15
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:460
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:461
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:716
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:717
Definition: binary_element_wise_operation.hpp:293
Bilinear(float alpha=1.f, float beta=1.f)
Definition: binary_element_wise_operation.hpp:296
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:294
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &) const
float beta_
Definition: binary_element_wise_operation.hpp:365
float alpha_
Definition: binary_element_wise_operation.hpp:362
Definition: binary_element_wise_operation.hpp:768
float scale_in_
Definition: binary_element_wise_operation.hpp:788
float scale_wei_
Definition: binary_element_wise_operation.hpp:791
__host__ __device__ ConvScaleAdd(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition: binary_element_wise_operation.hpp:771
float scale_out_
Definition: binary_element_wise_operation.hpp:792
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:769
__host__ __device__ void operator()(E &e, const C &c, const D &d) const
Definition: unary_element_wise_operation.hpp:924
Definition: binary_element_wise_operation.hpp:98
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:99
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:102
Definition: binary_element_wise_operation.hpp:111
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:112
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:115
Definition: binary_element_wise_operation.hpp:649
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:650
Definition: binary_element_wise_operation.hpp:124
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:125
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:218
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:224
float scale_
Definition: binary_element_wise_operation.hpp:241
__host__ __device__ ScaleAdd(float scale=1.f)
Definition: binary_element_wise_operation.hpp:221
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:219
Definition: unary_element_wise_operation.hpp:1087
Definition: binary_element_wise_operation.hpp:247
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:248
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const