/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/reduction_operator.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/reduction_operator.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/reduction_operator.hpp Source File
reduction_operator.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/ck.hpp"
8 #include "ck/utility/type.hpp"
10 
11 namespace ck {
12 
13 namespace reduce {
14 
// Every binary operator used in reduction is represented by a functor class (with templated
// members). Each functor class must provide at least three members:
// 1) GetIdentityValue() -- returns the "identity element" of the binary operator. The identity
//    element is the unique element of the algebraic space that does not affect the value of other
//    elements when operated against them; the concept is similar to the zero vector in a vector
//    space
//    (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
// 2) IsCompatibleInMemoryDataOperation() -- returns true if the reduction task corresponding to
//    this operator can use the given InMemoryDataOperation to finalize; otherwise it returns
//    false.
// 3) operator() -- the first argument must be both an input and an output; the corresponding
//    variable usually stores the accumulated result of many operator() calls. The second argument
//    is only an input. For an indexable binary operator, a second version of operator() takes a
//    third argument (an output) that indicates whether the accumulated value (the first argument)
//    has changed, in which case the recorded accumulated index also needs to be changed.

35 
36 struct Add
37 {
38  template <typename T>
39  __host__ __device__ static constexpr T GetIdentityValue()
40  {
41  return type_convert<T>(0.0f);
42  };
43 
44  __host__ __device__ static constexpr bool
46  {
47  return operation == InMemoryDataOperationEnum::AtomicAdd ||
48  operation == InMemoryDataOperationEnum::Set;
49  };
50 
51  template <typename T>
52  __host__ __device__ inline constexpr void operator()(T& a, T b) const
53  {
56  "The data type is not supported by the Add accumulator!");
57 
58  a = a + b;
59  }
60 
61  __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
62  {
63  float a_ = type_convert<float>(a);
64  float b_ = type_convert<float>(b);
65 
66  a = type_convert<f8_t>(a_ + b_);
67  }
68 
69  __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
70  {
71  float a_ = type_convert<float>(a);
72  float b_ = type_convert<float>(b);
73 
74  a = type_convert<half_t>(a_ + b_);
75  }
76 
77  __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
78  {
79  float a_ = type_convert<float>(a);
80  float b_ = type_convert<float>(b);
81 
82  a = type_convert<bhalf_t>(a_ + b_);
83  }
84 };
85 
86 struct SquaredAdd
87 {
88  template <class T>
89  __host__ __device__ static constexpr T GetIdentityValue()
90  {
91  return type_convert<T>(0.0f);
92  };
93 
94  __host__ __device__ static constexpr bool
96  {
97  return operation == InMemoryDataOperationEnum::AtomicAdd ||
98  operation == InMemoryDataOperationEnum::Set;
99  };
100 
101  template <class T>
102  __host__ __device__ inline constexpr void operator()(T& a, T b) const
103  {
107  "The data type is not supported by the SquaredAdd accumulator!");
108 
109  a = a + b * b;
110  }
111 };
112 
113 struct Mul
114 {
115  template <typename T>
116  __host__ __device__ static constexpr T GetIdentityValue()
117  {
118  return type_convert<T>(1.0f);
119  };
120 
121  __host__ __device__ static constexpr bool
123  {
124  return operation == InMemoryDataOperationEnum::Set;
125  };
126 
127  template <typename T>
128  __host__ __device__ inline constexpr void operator()(T& a, T b) const
129  {
132  "The data type is not supported by the Mul accumulator!");
133 
134  a = a * b;
135  }
136 
137  __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
138  {
139  float a_ = type_convert<float>(a);
140  float b_ = type_convert<float>(b);
141 
142  a = type_convert<f8_t>(a_ * b_);
143  }
144 
145  __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
146  {
147  float a_ = type_convert<float>(a);
148  float b_ = type_convert<float>(b);
149 
150  a = type_convert<half_t>(a_ * b_);
151  }
152 
153  __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
154  {
155  float a_ = type_convert<float>(a);
156  float b_ = type_convert<float>(b);
157 
158  a = type_convert<bhalf_t>(a_ * b_);
159  }
160 };
161 
162 struct Max
163 {
164  template <typename T>
165  __host__ __device__ static constexpr T GetIdentityValue()
166  {
167  if constexpr(is_same_v<T, bhalf_t>)
168  {
169  float val = NumericLimits<float>::Lowest();
170  return type_convert<bhalf_t>(val);
171  }
172  if constexpr(is_same_v<T, f8_t>)
173  {
174  float val = NumericLimits<float>::Lowest();
175  return type_convert<f8_t>(val);
176  }
177  if constexpr(is_same_v<T, half_t>)
178  {
179  float val = NumericLimits<float>::Lowest();
180  return type_convert<half_t>(val);
181  }
182  else
183  {
184  return NumericLimits<T>::Lowest();
185  }
186  };
187 
188  __host__ __device__ static constexpr bool
190  {
191  // ToChange: atomic_max to be added
192  return operation == InMemoryDataOperationEnum::Set;
193  };
194 
195  template <typename T>
196  __host__ __device__ inline constexpr void operator()(T& a, T b) const
197  {
200  "The data type is not supported by the Max accumulator!");
201 
202  if(a < b)
203  a = b;
204  }
205 
206  __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
207  {
208  float a_ = type_convert<float>(a);
209  float b_ = type_convert<float>(b);
210 
211  if(a_ < b_)
212  a = b;
213  }
214 
215  __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
216  {
217  float a_ = type_convert<float>(a);
218  float b_ = type_convert<float>(b);
219 
220  if(a_ < b_)
221  a = b;
222  }
223 
224  __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
225  {
226  float a_ = type_convert<float>(a);
227  float b_ = type_convert<float>(b);
228 
229  if(a_ < b_)
230  a = b;
231  }
232 
233  template <typename T>
234  __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
235  {
238  "The data type is not supported by the Max accumulator!");
239 
240  if(a < b)
241  {
242  a = b;
243  changed = true;
244  }
245  }
246 
247  __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
248  {
249  float a_ = type_convert<float>(a);
250  float b_ = type_convert<float>(b);
251 
252  if(a_ < b_)
253  {
254  a = b;
255  changed = true;
256  }
257  }
258 
259  __host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
260  {
261  float a_ = type_convert<float>(a);
262  float b_ = type_convert<float>(b);
263 
264  if(a_ < b_)
265  {
266  a = b;
267  changed = true;
268  }
269  }
270 
271  __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
272  {
273  float a_ = type_convert<float>(a);
274  float b_ = type_convert<float>(b);
275 
276  if(a_ < b_)
277  {
278  a = b;
279  changed = true;
280  }
281  }
282 };
283 
284 struct Min
285 {
286  template <typename T>
287  __host__ __device__ static constexpr T GetIdentityValue()
288  {
289  if constexpr(is_same_v<T, bhalf_t>)
290  {
291  float val = NumericLimits<float>::Max();
292  return type_convert<bhalf_t>(val);
293  }
294  else if constexpr(is_same_v<T, half_t>)
295  {
296  float val = NumericLimits<float>::Max();
297  return type_convert<half_t>(val);
298  }
299  else if constexpr(is_same_v<T, f8_t>)
300  {
301  float val = NumericLimits<float>::Max();
302  return type_convert<f8_t>(val);
303  }
304  else
305  {
306  return NumericLimits<T>::Max();
307  }
308  return NumericLimits<T>::Max();
309  };
310 
311  __host__ __device__ static constexpr bool
313  {
314  // ToChange: atomic_min to be added
315  return operation == InMemoryDataOperationEnum::Set;
316  };
317 
318  template <typename T>
319  __host__ __device__ inline constexpr void operator()(T& a, T b) const
320  {
323  "The data type is not supported by the Min accumulator!");
324 
325  if(a > b)
326  a = b;
327  }
328 
329  __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
330  {
331  float a_ = type_convert<float>(a);
332  float b_ = type_convert<float>(b);
333 
334  if(a_ > b_)
335  a = b;
336  }
337 
338  __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
339  {
340  float a_ = type_convert<float>(a);
341  float b_ = type_convert<float>(b);
342 
343  if(a_ > b_)
344  a = b;
345  }
346 
347  __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
348  {
349  float a_ = type_convert<float>(a);
350  float b_ = type_convert<float>(b);
351 
352  if(a_ > b_)
353  a = b;
354  }
355 
356  template <typename T>
357  __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
358  {
362  "The data type is not supported by the Min accumulator!");
363 
364  if(a > b)
365  {
366  a = b;
367  changed = true;
368  }
369  }
370 
371  __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
372  {
373  float a_ = type_convert<float>(a);
374  float b_ = type_convert<float>(b);
375 
376  if(a_ > b_)
377  {
378  a = b;
379  changed = true;
380  }
381  }
382 
383  __host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
384  {
385  float a_ = type_convert<float>(a);
386  float b_ = type_convert<float>(b);
387 
388  if(a_ > b_)
389  {
390  a = b;
391  changed = true;
392  }
393  }
394 
395  __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
396  {
397  float a_ = type_convert<float>(a);
398  float b_ = type_convert<float>(b);
399 
400  if(a_ > b_)
401  {
402  a = b;
403  changed = true;
404  }
405  }
406 };
407 
408 struct AMax
409 {
410  template <typename T>
411  __host__ __device__ static constexpr T GetIdentityValue()
412  {
413  return type_convert<T>(0.0f);
414  };
415 
416  __host__ __device__ static constexpr bool
418  {
419  // ToChange: atomic_max to be added
420  return operation == InMemoryDataOperationEnum::Set;
421  };
422 
423  template <typename T>
424  __host__ __device__ inline constexpr void operator()(T& a, T b) const
425  {
429  "The data type is not supported by the AMax accumulator!");
430 
431  if(a < b)
432  a = b;
433  }
434 
435  __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
436  {
437  float a_ = type_convert<float>(a);
438  float b_ = type_convert<float>(b);
439 
440  if(a_ < b_)
441  a = b;
442  }
443 
444  template <typename T>
445  __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
446  {
450  "The data type is not supported by the AMax accumulator!");
451 
452  if(a < b)
453  {
454  a = b;
455  changed = true;
456  }
457  }
458 
459  __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
460  {
461  float a_ = type_convert<float>(a);
462  float b_ = type_convert<float>(b);
463 
464  if(a_ < b_)
465  {
466  a = b;
467  changed = true;
468  }
469  }
470 };
471 
472 template <typename T>
474 {
475  T result = ck::type_convert<T>(0.0f);
476 
477  if(operation == InMemoryDataOperationEnum::AtomicMax)
478  result = ck::NumericLimits<T>::Lowest();
479 
480  return (result);
481 };
482 
483 template <InMemoryDataOperationEnum Operation, typename DataType>
485 {
486  static constexpr bool value = false;
487 };
488 
489 template <typename DataType>
491 {
492  static constexpr bool value =
494 };
495 
496 template <typename DataType>
498 {
499  static constexpr bool value =
501 };
502 
503 template <typename DataType>
505 {
506  static constexpr bool value =
511 };
512 
513 template <typename DataType>
515 {
516  static constexpr bool value =
520 };
521 
522 } // namespace reduce
523 } // namespace ck
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition: reduction_operator.hpp:473
Definition: ck.hpp:267
InMemoryDataOperationEnum
Definition: ck.hpp:276
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
_Float16 half_t
Definition: data_type.hpp:30
ushort bhalf_t
Definition: data_type.hpp:29
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1249
__host__ static constexpr __device__ T Lowest()
Definition: numeric_limits.hpp:312
__host__ static constexpr __device__ T Max()
Definition: numeric_limits.hpp:311
Definition: type.hpp:177
Definition: reduction_operator.hpp:409
__host__ constexpr __device__ void operator()(T &a, T b, bool &changed) const
Definition: reduction_operator.hpp:445
__host__ constexpr __device__ void operator()(T &a, T b) const
Definition: reduction_operator.hpp:424
__host__ constexpr __device__ void operator()(f8_t &a, f8_t b) const
Definition: reduction_operator.hpp:435
__host__ constexpr __device__ void operator()(f8_t &a, f8_t b, bool &changed) const
Definition: reduction_operator.hpp:459
__host__ static constexpr __device__ T GetIdentityValue()
Definition: reduction_operator.hpp:411
__host__ static constexpr __device__ bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition: reduction_operator.hpp:417
Definition: reduction_operator.hpp:37
__host__ constexpr __device__ void operator()(f8_t &a, f8_t b) const
Definition: reduction_operator.hpp:61
__host__ static constexpr __device__ T GetIdentityValue()
Definition: reduction_operator.hpp:39
__host__ constexpr __device__ void operator()(half_t &a, half_t b) const
Definition: reduction_operator.hpp:69
__host__ constexpr __device__ void operator()(T &a, T b) const
Definition: reduction_operator.hpp:52
__host__ constexpr __device__ void operator()(bhalf_t &a, bhalf_t b) const
Definition: reduction_operator.hpp:77
__host__ static constexpr __device__ bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition: reduction_operator.hpp:45
Definition: reduction_operator.hpp:485
static constexpr bool value
Definition: reduction_operator.hpp:486
Definition: reduction_operator.hpp:163
__host__ constexpr __device__ void operator()(bhalf_t &a, bhalf_t b, bool &changed) const
Definition: reduction_operator.hpp:247
__host__ constexpr __device__ void operator()(half_t &a, half_t b, bool &changed) const
Definition: reduction_operator.hpp:259
__host__ constexpr __device__ void operator()(f8_t &a, f8_t b) const
Definition: reduction_operator.hpp:224
__host__ constexpr __device__ void operator()(T &a, T b, bool &changed) const
Definition: reduction_operator.hpp:234
__host__ constexpr __device__ void operator()(bhalf_t &a, bhalf_t b) const
Definition: reduction_operator.hpp:206
__host__ static constexpr __device__ T GetIdentityValue()
Definition: reduction_operator.hpp:165
__host__ constexpr __device__ void operator()(f8_t &a, f8_t b, bool &changed) const
Definition: reduction_operator.hpp:271
__host__ constexpr __device__ void operator()(half_t &a, half_t b) const
Definition: reduction_operator.hpp:215
__host__ static constexpr __device__ bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition: reduction_operator.hpp:189
__host__ constexpr __device__ void operator()(T &a, T b) const
Definition: reduction_operator.hpp:196
Definition: reduction_operator.hpp:285
__host__ constexpr __device__ void operator()(half_t &a, half_t b) const
Definition: reduction_operator.hpp:338
__host__ constexpr __device__ void operator()(f8_t &a, f8_t b, bool &changed) const
Definition: reduction_operator.hpp:395
__host__ constexpr __device__ void operator()(bhalf_t &a, bhalf_t b, bool &changed) const
Definition: reduction_operator.hpp:371
__host__ static constexpr __device__ bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition: reduction_operator.hpp:312
__host__ constexpr __device__ void operator()(half_t &a, half_t b, bool &changed) const
Definition: reduction_operator.hpp:383
__host__ constexpr __device__ void operator()(T &a, T b, bool &changed) const
Definition: reduction_operator.hpp:357
__host__ constexpr __device__ void operator()(bhalf_t &a, bhalf_t b) const
Definition: reduction_operator.hpp:329
__host__ static constexpr __device__ T GetIdentityValue()
Definition: reduction_operator.hpp:287
__host__ constexpr __device__ void operator()(f8_t &a, f8_t b) const
Definition: reduction_operator.hpp:347
__host__ constexpr __device__ void operator()(T &a, T b) const
Definition: reduction_operator.hpp:319
Definition: reduction_operator.hpp:114
__host__ constexpr __device__ void operator()(T &a, T b) const
Definition: reduction_operator.hpp:128
__host__ static constexpr __device__ bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition: reduction_operator.hpp:122
__host__ constexpr __device__ void operator()(f8_t &a, f8_t b) const
Definition: reduction_operator.hpp:137
__host__ static constexpr __device__ T GetIdentityValue()
Definition: reduction_operator.hpp:116
__host__ constexpr __device__ void operator()(bhalf_t &a, bhalf_t b) const
Definition: reduction_operator.hpp:153
__host__ constexpr __device__ void operator()(half_t &a, half_t b) const
Definition: reduction_operator.hpp:145
Definition: reduction_operator.hpp:87
__host__ static constexpr __device__ bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition: reduction_operator.hpp:95
__host__ static constexpr __device__ T GetIdentityValue()
Definition: reduction_operator.hpp:89
__host__ constexpr __device__ void operator()(T &a, T b) const
Definition: reduction_operator.hpp:102