/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp Source File
binary_element_wise_operation.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 namespace tensor_operation {
11 namespace element_wise {
12 
13 struct Add
14 {
15  template <typename Y, typename X0, typename X1>
16  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
17 
18  template <>
19  __host__ __device__ constexpr void
20  operator()<float>(float& y, const float& x0, const float& x1) const
21  {
22  y = x0 + x1;
23  };
24 
25  template <>
26  __host__ __device__ constexpr void
27  operator()<double>(double& y, const double& x0, const double& x1) const
28  {
29  y = x0 + x1;
30  };
31 
32  template <>
33  __host__ __device__ constexpr void
34  operator()<float>(float& y, const float& x0, const half_t& x1) const
35  {
36  y = x0 + type_convert<half_t>(x1);
37  };
38 
39  template <>
40  __host__ __device__ constexpr void
41  operator()<half_t>(half_t& y, const float& x0, const float& x1) const
42  {
43  y = type_convert<half_t>(x0 + x1);
44  };
45 
46  template <>
47  __host__ __device__ constexpr void
48  operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
49  {
50  y = type_convert<half_t>(x0) + x1;
51  };
52 
53  template <>
54  __host__ __device__ constexpr void
55  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
56  {
57  y = x0 + x1;
58  };
59 
60  template <>
61  __host__ __device__ constexpr void
62  operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
63  {
64  const float x1_tmp = ck::type_convert<float>(x1);
65  y = x0 + x1_tmp;
66  }
67 
68  template <>
69  __host__ __device__ constexpr void
70  operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
71  {
72  const float x1_tmp = ck::type_convert<float>(x0);
73  const float x2_tmp = ck::type_convert<float>(x1);
74  const float y_tmp = x1_tmp + x2_tmp;
75  y = ck::type_convert<bhalf_t>(y_tmp);
76  }
77 
78  template <>
79  __host__ __device__ constexpr void
80  operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
81  {
82  const float x2_tmp = ck::type_convert<float>(x1);
83  const float y_tmp = x0 + x2_tmp;
84  y = ck::type_convert<bhalf_t>(y_tmp);
85  }
86 
87  template <>
88  __host__ __device__ constexpr void
89  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
90  {
91  y = x0 + x1;
92  };
93 };
94 
95 struct Max
96 {
97  template <typename Y, typename X0, typename X1>
98  __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
99  {
100  const Y x0_converted = type_convert<Y>(x0);
101  const Y x1_converted = type_convert<Y>(x1);
102  y = ck::math::max(x0_converted, x1_converted);
103  }
104 };
105 
106 struct Min
107 {
108  template <typename Y, typename X0, typename X1>
109  __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
110  {
111  const Y x0_converted = type_convert<Y>(x0);
112  const Y x1_converted = type_convert<Y>(x1);
113  y = ck::math::min(x0_converted, x1_converted);
114  }
115 };
116 
117 struct Multiply
118 {
119  template <typename Y, typename X0, typename X1>
120  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
121 
122  template <>
123  __host__ __device__ constexpr void
124  operator()<float>(float& y, const float& x0, const float& x1) const
125  {
126  y = x0 * x1;
127  };
128 
129  template <>
130  __host__ __device__ constexpr void
131  operator()<double>(double& y, const double& x0, const double& x1) const
132  {
133  y = x0 * x1;
134  };
135 
136  template <>
137  __host__ __device__ constexpr void
138  operator()<float>(float& y, const float& x0, const half_t& x1) const
139  {
140  y = x0 * type_convert<half_t>(x1);
141  };
142 
143  template <>
144  __host__ __device__ constexpr void
145  operator()<half_t>(half_t& y, const float& x0, const float& x1) const
146  {
147  y = type_convert<half_t>(x0 * x1);
148  };
149 
150  template <>
151  __host__ __device__ constexpr void
152  operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
153  {
154  y = type_convert<half_t>(x0) * x1;
155  };
156 
157  template <>
158  __host__ __device__ constexpr void
159  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
160  {
161  y = x0 * x1;
162  };
163 
164  template <>
165  __host__ __device__ constexpr void
166  operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
167  {
168  const float x1_tmp = ck::type_convert<float>(x1);
169  y = x0 * x1_tmp;
170  }
171 
172  template <>
173  __host__ __device__ constexpr void
174  operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
175  {
176  const float x1_tmp = ck::type_convert<float>(x0);
177  const float x2_tmp = ck::type_convert<float>(x1);
178  const float y_tmp = x1_tmp * x2_tmp;
179  y = ck::type_convert<bhalf_t>(y_tmp);
180  }
181 
182  template <>
183  __host__ __device__ constexpr void
184  operator()<bhalf_t>(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const
185  {
186  const float x1_tmp = ck::type_convert<float>(x0);
187  const float x2_tmp = ck::type_convert<float>(x1);
188  const float y_tmp = x1_tmp * x2_tmp;
189  y = ck::type_convert<bhalf_t>(y_tmp);
190  }
191 
192  template <>
193  __host__ __device__ constexpr void
194  operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
195  {
196  const float x2_tmp = ck::type_convert<float>(x1);
197  const float y_tmp = x0 * x2_tmp;
198  y = ck::type_convert<bhalf_t>(y_tmp);
199  }
200 
201  template <>
202  __host__ __device__ constexpr void
203  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
204  {
205  y = x0 * x1;
206  };
207 };
208 
209 struct ScaleAdd
210 {
211  __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
212 
213  template <typename Y, typename X0, typename X1>
214  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
215  {
216  y = ck::type_convert<Y>(scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
217  }
218 
219  template <>
220  __host__ __device__ void
221  operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
222  {
223  y = scale_ * x0 + ck::type_convert<float>(x1);
224  };
225 
226  template <>
227  __host__ __device__ void
228  operator()<float, float, bhalf_t>(float& y, const float& x0, const bhalf_t& x1) const
229  {
230  y = scale_ * x0 + ck::type_convert<float>(x1);
231  };
232 
233  float scale_;
234 };
235 
236 struct Subtract
237 {
238  template <typename T>
239  __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
240 
241  template <>
242  __host__ __device__ constexpr void
243  operator()<float>(float& y, const float& x0, const float& x1) const
244  {
245  y = x0 - x1;
246  };
247 
248  template <>
249  __host__ __device__ constexpr void
250  operator()<double>(double& y, const double& x0, const double& x1) const
251  {
252  y = x0 - x1;
253  };
254 
255  template <>
256  __host__ __device__ constexpr void
257  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
258  {
259  y = x0 - x1;
260  };
261 
262  template <>
263  __host__ __device__ constexpr void
264  operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
265  {
266  const float x1_tmp = ck::type_convert<float>(x0);
267  const float x2_tmp = ck::type_convert<float>(x1);
268  const float y_tmp = x1_tmp - x2_tmp;
269  y = ck::type_convert<bhalf_t>(y_tmp);
270  }
271 
272  template <>
273  __host__ __device__ constexpr void
274  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
275  {
276  y = x0 - x1;
277  };
278 };
279 
280 struct Bilinear
281 {
282  Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
283 
284  template <typename Y, typename X0, typename X1>
285  __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
286 
287  template <>
288  __host__ __device__ constexpr void
289  operator()<double, double, double>(double& y, const double& x0, const double& x1) const
290  {
291  y = alpha_ * x0 + beta_ * x1;
292  };
293 
294  template <>
295  __host__ __device__ constexpr void
296  operator()<float, float, float>(float& y, const float& x0, const float& x1) const
297  {
298  y = alpha_ * x0 + beta_ * x1;
299  };
300 
301  template <>
302  __host__ __device__ constexpr void
303  operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
304  {
305  y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
306  beta_ * type_convert<float>(x1));
307  };
308 
309  template <>
310  __host__ __device__ constexpr void
311  operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
312  {
313  y = type_convert<half_t>(alpha_) * x0 + type_convert<half_t>(beta_) * x1;
314  };
315 
316  template <>
317  __host__ __device__ constexpr void
318  operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
319  {
320  y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
321  };
322 
323  template <>
324  __host__ __device__ constexpr void
325  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
326  {
327  const float x0_tmp = type_convert<float>(x0);
328  const float x1_tmp = type_convert<float>(x1);
329  const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp;
330  y = type_convert<bhalf_t>(y_tmp);
331  };
332 
333  template <>
334  __host__ __device__ constexpr void
335  operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
336  {
337  const float x1_tmp = ck::type_convert<float>(x1);
338  const float y_tmp = alpha_ * x0 + beta_ * x1_tmp;
339  y = y_tmp;
340  };
341 
342  template <>
343  __host__ __device__ constexpr void
344  operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
345  {
346  y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
347  beta_ * type_convert<float>(x1));
348  };
349 
350  float alpha_;
351  float beta_;
352 };
353 
354 struct AddClamp
355 {
357  : floor_(floor), ceil_(ceil){};
358 
359  template <typename Y, typename X0, typename X1>
360  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
361 
362  template <>
363  __host__ __device__ constexpr void
364  operator()<float, float, float>(float& y, const float& x0, const float& x1) const
365  {
366  const float a = x0 + x1;
367  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
368  };
369 
370  template <>
371  __host__ __device__ constexpr void
372  operator()<double, double, double>(double& y, const double& x0, const double& x1) const
373  {
374  const double a = x0 + x1;
375  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
376  };
377 
378  template <>
379  __host__ __device__ constexpr void
380  operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
381  {
382  const half_t floor = type_convert<half_t>(floor_);
383  const half_t ceil = type_convert<half_t>(ceil_);
384  const half_t a = x0 + x1;
385  y = a > floor ? (a < ceil ? a : ceil) : floor;
386  };
387 
388  template <>
389  __host__ __device__ constexpr void
390  operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
391  {
392  const float a = x0 + type_convert<float>(x1);
393  const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
394  y = type_convert<half_t>(b);
395  };
396 
397  template <>
398  __host__ __device__ constexpr void
399  operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
400  {
401  const float a = x0 + type_convert<float>(x1);
402  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
403  };
404 
405  template <>
406  __host__ __device__ constexpr void
407  operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
408  {
409  const float a = x0 + type_convert<float>(x1);
410  const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
411  y = type_convert<bhalf_t>(b);
412  };
413 
414  template <>
415  __host__ __device__ constexpr void
416  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
417  {
418  const float a = type_convert<float>(x0) + type_convert<float>(x1);
419  const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
420  y = type_convert<bhalf_t>(b);
421  };
422 
423  template <>
424  __host__ __device__ constexpr void
425  operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
426  {
427  const int8_t a = x0 + x1;
428  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
429  };
430 
431  template <>
432  __host__ __device__ constexpr void
433  operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
434  {
435  const int8_t a = x0 + x1;
436  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
437  };
438 
439  const float floor_;
440  const float ceil_;
441 };
442 
443 struct AddRelu
444 {
445  template <typename Y, typename X0, typename X1>
446  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
447 
448  template <>
449  __host__ __device__ constexpr void
450  operator()<float, float, float>(float& y, const float& x0, const float& x1) const
451  {
452  const float a = x0 + x1;
453  y = a > 0.0f ? a : 0.0f;
454  };
455 
456  template <>
457  __host__ __device__ constexpr void
458  operator()<double, double, double>(double& y, const double& x0, const double& x1) const
459  {
460  const double a = x0 + x1;
461  y = a > 0.0 ? a : 0.0;
462  };
463 
464  template <>
465  __host__ __device__ constexpr void
466  operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
467  {
468  const half_t a = x0 + x1;
469  y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
470  };
471 
472  template <>
473  __host__ __device__ constexpr void
474  operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
475  {
476  const float a = x0 + type_convert<float>(x1);
477  const float b = a > 0.0f ? a : 0.0f;
478  y = type_convert<half_t>(b);
479  };
480 
481  template <>
482  __host__ __device__ constexpr void
483  operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
484  {
485  const float a = x0 + type_convert<float>(x1);
486  y = a > 0.0f ? a : 0.0f;
487  };
488 
489  template <>
490  __host__ __device__ constexpr void
491  operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
492  {
493  const float a = x0 + type_convert<float>(x1);
494  const float b = a > 0.0f ? a : 0.0f;
495  y = type_convert<bhalf_t>(b);
496  };
497 
498  template <>
499  __host__ __device__ constexpr void
500  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
501  {
502  const float a = type_convert<float>(x0) + type_convert<float>(x1);
503  const float b = a > 0.0f ? a : 0.0f;
504  y = type_convert<bhalf_t>(b);
505  };
506 
507  template <>
508  __host__ __device__ constexpr void
509  operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
510  {
511  const int8_t a = x0 + x1;
512  y = a > 0 ? a : 0;
513  };
514 
515  template <>
516  __host__ __device__ constexpr void
517  operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
518  {
519  const int8_t a = x0 + x1;
520  y = a > 0 ? a : 0;
521  };
522 };
523 
525 {
526  template <typename T>
527  __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
528 
529  template <>
530  __host__ __device__ constexpr void
531  operator()<float>(float& y, const float& x0, const float& x1) const
532  {
533  float a = x0 + x1;
534  float b = a + float{3};
535  float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
536  y = c;
537  };
538 
539  template <>
540  __host__ __device__ constexpr void
541  operator()<double>(double& y, const double& x0, const double& x1) const
542  {
543  double a = x0 + x1;
544  double b = a + 3.0;
545  double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667;
546  y = c;
547  };
548 
549  template <>
550  __host__ __device__ constexpr void
551  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
552  {
553  float a = x0 + x1;
554  float b = a + 3.0f;
555  float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
556  y = c;
557  };
558 };
559 
560 // E = FastGelu(C + D)
562 {
563  template <typename E, typename C, typename D>
564  __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
565 
566  template <>
567  __host__ __device__ constexpr void
568  operator()<float, float, float>(float& e, const float& c, const float& d) const
569  {
570  const float x = c + d;
571 
572  FastGelu{}.template operator()<float, float>(e, x);
573  }
574 
575  template <>
576  __host__ __device__ constexpr void
577  operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
578  {
579  const half_t x = c + d;
580 
581  ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
582  }
583 
584  template <>
585  __host__ __device__ constexpr void
586  operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
587  {
588  const float x0_f = c + d;
589 
590  float x1_f = 0;
591 
592  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
593  x0_f);
594 
595  e = type_convert<half_t>(x1_f);
596  }
597 
598  template <>
599  __host__ __device__ constexpr void
600  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
601  {
602  const float x0_f = type_convert<float>(c) + type_convert<float>(d);
603 
604  float x1_f = 0;
605 
606  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
607 
608  e = type_convert<bhalf_t>(x1_f);
609  }
610 
611  template <>
612  __host__ __device__ constexpr void
613  operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
614  {
615  const float x0_f = c + type_convert<float>(d);
616 
617  float x1_f = 0;
618 
619  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
620 
621  e = type_convert<bhalf_t>(x1_f);
622  }
623 };
624 
625 // E = MultiplyFastGelu(C + D)
627 {
628  template <typename E, typename C, typename D>
629  __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
630 
631  template <>
632  __host__ __device__ constexpr void
633  operator()<float, float, float>(float& e, const float& c, const float& d) const
634  {
635  const float x = c * d;
636 
637  FastGelu{}.template operator()<float, float>(e, x);
638  }
639 
640  template <>
641  __host__ __device__ constexpr void
642  operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
643  {
644  const half_t x = c * d;
645 
646  ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
647  }
648 
649  template <>
650  __host__ __device__ constexpr void
651  operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
652  {
653  const float x0_f = c * d;
654 
655  float x1_f = 0;
656 
657  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
658  x0_f);
659 
660  e = type_convert<half_t>(x1_f);
661  }
662 
663  template <>
664  __host__ __device__ constexpr void
665  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
666  {
667  const float x0_f = type_convert<float>(c) * type_convert<float>(d);
668 
669  float x1_f = 0;
670 
671  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
672 
673  e = type_convert<bhalf_t>(x1_f);
674  }
675 
676  template <>
677  __host__ __device__ constexpr void
678  operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
679  {
680  const float x0_f = c * type_convert<float>(d);
681 
682  float x1_f = 0;
683 
684  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
685 
686  e = type_convert<bhalf_t>(x1_f);
687  }
688 };
689 
690 // E = Silu(C + D)
691 struct AddSilu
692 {
693  template <typename E, typename C, typename D>
694  __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
695 
696  template <>
697  __host__ __device__ constexpr void
698  operator()<float, float, float>(float& e, const float& c, const float& d) const
699  {
700  const float x = c + d;
701 
702  Silu{}.template operator()<float>(e, x);
703  }
704 
705  template <>
706  __host__ __device__ constexpr void
707  operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
708  {
709  const half_t x = c + d;
710 
711  Silu{}.template operator()<half_t>(e, x);
712  }
713 
714  template <>
715  __host__ __device__ constexpr void
716  operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
717  {
718  const float x0_f = c + d;
719 
720  float x1_f = 0;
721 
722  Silu{}.template operator()<float>(x1_f, x0_f);
723 
724  e = type_convert<half_t>(x1_f);
725  }
726 
727  template <>
728  __host__ __device__ constexpr void
729  operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
730  {
731  const float x0_f = c + type_convert<float>(d);
732 
733  float x1_f = 0;
734 
735  Silu{}.template operator()<float>(x1_f, x0_f);
736 
737  e = type_convert<bhalf_t>(x1_f);
738  }
739 };
740 
742 {
743  __host__ __device__ ConvScaleAdd(float scale_in = 1.f,
744  float scale_wei = 1.f,
745  float scale_out = 1.f)
746  : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
747  {
748  }
749 
750  template <typename E, typename C, typename D>
751  __host__ __device__ void operator()(E& e, const C& c, const D& d) const;
752 
753  template <>
754  __host__ __device__ void
755  operator()<f8_t, float, float>(f8_t& e, const float& c, const float& d) const
756  {
757  float x;
758  Add{}.template operator()<float>(x, c * scale_in_ * scale_wei_, d);
759  e = type_convert<f8_t>(x * scale_out_);
760  };
761 
762  float scale_in_;
763  float scale_wei_;
764  float scale_out_;
765 };
766 
767 } // namespace element_wise
768 } // namespace tensor_operation
769 } // namespace ck
__host__ T ceil(T x)
Definition: math_v2.hpp:331
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: ck.hpp:267
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
_Float16 half_t
Definition: data_type.hpp:30
ushort bhalf_t
Definition: data_type.hpp:29
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:309
Definition: binary_element_wise_operation.hpp:355
AddClamp(float floor=0.f, float ceil=NumericLimits< float >::Max())
Definition: binary_element_wise_operation.hpp:356
const float ceil_
Definition: binary_element_wise_operation.hpp:440
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
const float floor_
Definition: binary_element_wise_operation.hpp:437
Definition: binary_element_wise_operation.hpp:562
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
Definition: binary_element_wise_operation.hpp:525
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const
Definition: binary_element_wise_operation.hpp:14
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:444
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:692
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
Definition: binary_element_wise_operation.hpp:281
Bilinear(float alpha=1.f, float beta=1.f)
Definition: binary_element_wise_operation.hpp:282
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &) const
float beta_
Definition: binary_element_wise_operation.hpp:351
float alpha_
Definition: binary_element_wise_operation.hpp:348
Definition: binary_element_wise_operation.hpp:742
float scale_in_
Definition: binary_element_wise_operation.hpp:760
float scale_wei_
Definition: binary_element_wise_operation.hpp:763
__host__ __device__ ConvScaleAdd(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition: binary_element_wise_operation.hpp:743
float scale_out_
Definition: binary_element_wise_operation.hpp:764
__host__ __device__ void operator()(E &e, const C &c, const D &d) const
Definition: unary_element_wise_operation.hpp:892
Definition: binary_element_wise_operation.hpp:96
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:98
Definition: binary_element_wise_operation.hpp:107
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:109
Definition: binary_element_wise_operation.hpp:627
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
Definition: binary_element_wise_operation.hpp:118
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:210
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:214
float scale_
Definition: binary_element_wise_operation.hpp:231
__host__ __device__ ScaleAdd(float scale=1.f)
Definition: binary_element_wise_operation.hpp:211
Definition: unary_element_wise_operation.hpp:1049
Definition: binary_element_wise_operation.hpp:237
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const