/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/reduction_functions_accumulate.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/reduction_functions_accumulate.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/utility/reduction_functions_accumulate.hpp Source File
reduction_functions_accumulate.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math_v2.hpp"
10 
11 namespace ck {
12 namespace detail {
13 
14 // Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs)
15 template <typename ReduceOperation, typename AccDataType>
17 {
18  __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
19  {
20  if(!ck::math::isnan(currVal))
21  {
22  ReduceOperation{}(accuVal, currVal);
23  }
24  };
25 };
26 
// Primary template: only declared here. The behavior is supplied by the partial
// specializations on PropagateNan (false / true) that follow.
template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck;
29 
30 // Does not check for NaN; does not guarantee NaNs be propagated to result
31 // e.g., given that max(a, b) = a > b ? a : b
32 // then max(NaN, 1) returns 1
33 // max(1, NaN) returns NaN
34 // since any comparison involving NaNs returns false
35 template <typename ReduceOperation, typename AccDataType>
36 struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
37 {
38  // cppcheck-suppress constParameter
39  __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
40  {
41  ReduceOperation{}(accuVal, currVal);
42  };
43 };
44 
45 // Check for NaN; guarantees NaNs be propagated to result
46 template <typename ReduceOperation, typename AccDataType>
47 struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
48 {
49  __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
50  {
51  using ck::math::isnan;
52 
53  if(isnan(currVal))
54  {
55  accuVal = currVal;
56  }
57  else
58  {
59  ReduceOperation{}(accuVal, currVal);
60  };
61  };
62 };
63 
// Primary template: only declared here. The index-tracking behavior is supplied by the
// partial specializations on PropagateNan (false / true) that follow.
template <bool PropagateNan, typename ReduceOperation, typename AccDataType, typename IndexDataType>
struct AccumulateWithIndexAndNanCheck;
66 
67 template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
68 struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
69 {
70  __host__ __device__ static inline void
71  // cppcheck-suppress constParameter
72  Calculate(AccDataType& accuVal,
73  AccDataType currVal,
74  IndexDataType& accuIndex,
75  IndexDataType currIndex)
76  {
77  bool changed = false;
78 
79  ReduceOperation{}(accuVal, currVal, changed);
80 
81  if(changed)
82  accuIndex = currIndex;
83  };
84 };
85 
86 template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
87 struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
88 {
89  // The method is called when the ReduceOperation is indexable and the user asked for indices
90  __host__ __device__ static inline void Calculate(AccDataType& accuVal,
91  AccDataType currVal,
92  IndexDataType& accuIndex,
93  IndexDataType currIndex)
94  {
95  using ck::math::isnan;
96 
97  if(isnan(currVal))
98  {
99  accuVal = currVal;
100  accuIndex = currIndex;
101  }
102  else
103  {
104  bool changed = false;
105 
106  ReduceOperation{}(accuVal, currVal, changed);
107 
108  if(changed)
109  accuIndex = currIndex;
110  }
111  };
112 };
113 
114 } // namespace detail
115 } // namespace ck
Definition: ck.hpp:267
__host__ static __device__ void Calculate(AccDataType &accuVal, AccDataType currVal, IndexDataType &accuIndex, IndexDataType currIndex)
Definition: reduction_functions_accumulate.hpp:72
__host__ static __device__ void Calculate(AccDataType &accuVal, AccDataType currVal, IndexDataType &accuIndex, IndexDataType currIndex)
Definition: reduction_functions_accumulate.hpp:90
Definition: reduction_functions_accumulate.hpp:65
__host__ static __device__ void Calculate(AccDataType &accuVal, AccDataType currVal)
Definition: reduction_functions_accumulate.hpp:39
__host__ static __device__ void Calculate(AccDataType &accuVal, AccDataType currVal)
Definition: reduction_functions_accumulate.hpp:49
Definition: reduction_functions_accumulate.hpp:28
Definition: reduction_functions_accumulate.hpp:17
static __device__ void Calculate(AccDataType &accuVal, AccDataType currVal)
Definition: reduction_functions_accumulate.hpp:18