// NOTE: this is a fragment of a blockwise reduction helper taken from a numbered source
// listing. The leading numbers are the listing's own line numbers; the lines missing between
// them (struct header, braces, loop header, offset computations) are not visible here.
// Visible template parameters: the accumulator type, the M/K thread-cluster lengths, the
// cluster arrange order, and the Accumulation policy, which defaults to
// detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>.
19 template <
typename AccDataType,
21 typename ThreadClusterLengths_M_K,
22 typename ThreadClusterArrangeOrder,
25 typename Accumulation =
26 detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
// Compile-time check: BlockSize must equal the product of the two cluster lengths.
29 static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
30 "The product of cluster lengths should be same as BlockSize!");
// Compile-time check: the K extent of the work buffer must be greater than one.
35 static_assert(
BufferLength_K > 1,
"Parallel reduction need work on at least two elements");
// Reduce: static device function, templated on the buffer type.
//   work_buffer  - buffer that is read and written during the reduction steps below.
//   in_out_value - receives the final value read back from work_buffer (last visible line).
43 template <
typename BufferType>
44 __device__
static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
// Tail of a static_assert whose condition is on a line not shown in this listing.
47 "Buffer data type should be consistent as AccDataType!");
// Compile-time shift derived from BufferLength_K.
// NOTE(review): presumably log2(BufferLength_K) - confirm against get_shift.
49 constexpr
auto cluster_len_shift = get_shift<BufferLength_K>();
// The initializer of thread_cluster_idx is on lines not shown; its components 0 and 1 are
// used as this thread's M and K cluster ids.
51 const auto thread_cluster_idx =
54 const auto thread_m_cluster_id = thread_cluster_idx[
Number<0>{}];
55 const auto thread_k_cluster_id = thread_cluster_idx[
Number<1>{}];
// Step width for the current step: 1 << (cluster_len_shift - 1 - I()), so it halves each time
// I() grows by one.
// NOTE(review): I is presumably the index of an enclosing compile-time loop - not shown.
62 constexpr
index_t indOffset = 1 << (cluster_len_shift - 1 - I());
// Only threads whose K cluster id is below the current step width take part in this step.
64 if(thread_k_cluster_id < indOffset)
// offset1 / offset2 are computed on lines not shown. The two operands are loaded, combined
// via Accumulation::Calculate, and opData1 is stored back at offset1.
// NOTE(review): any barrier between steps is not visible in this listing - confirm it exists.
70 AccDataType opData1 = work_buffer[offset1];
71 AccDataType opData2 = work_buffer[offset2];
72 Accumulation::Calculate(opData1, opData2);
73 work_buffer(offset1) = opData1;
// Final read-back into the caller's value; `offset` is computed on a line not shown.
81 in_out_value = work_buffer[offset];
// NOTE: this is a fragment of a second blockwise reduction helper taken from a numbered
// source listing. The leading numbers are the listing's own line numbers; the lines missing
// between them (struct header, braces, loop header, offset computations) are not visible here.
// Visible template parameters: the accumulator type, the M/K thread-cluster lengths, a
// ThreadClusterDesc type, and the Accumulation policy, which defaults to
// detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>.
92 template <
typename AccDataType,
94 typename ThreadClusterLengths_M_K,
95 typename ThreadClusterDesc,
98 typename Accumulation =
99 detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
// Compile-time check: BlockSize must equal the product of the two cluster lengths.
102 static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
103 "The product of cluster lengths should be same as BlockSize!");
// Compile-time check: the K extent of the work buffer must be greater than one.
108 static_assert(
BufferLength_K > 1,
"Parallel reduction need work on at least two elements");
// Reduce: static device function, templated on the buffer type.
//   work_buffer  - buffer that is read and written during the reduction steps below.
//   in_out_value - receives the final value read back from work_buffer (last visible line).
115 template <
typename BufferType>
116 __device__
static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
// Tail of a static_assert whose condition is on a line not shown in this listing.
119 "Buffer data type should be consistent as AccDataType!");
// Compile-time shift derived from BufferLength_K.
// NOTE(review): presumably log2(BufferLength_K) - confirm against get_shift.
121 constexpr
auto cluster_len_shift = get_shift<BufferLength_K>();
// The initializer of thread_cluster_idx is on lines not shown; its components 0 and 1 are
// used as this thread's M and K cluster ids.
123 const auto thread_cluster_idx =
126 const auto thread_m_cluster_id = thread_cluster_idx[
Number<0>{}];
127 const auto thread_k_cluster_id = thread_cluster_idx[
Number<1>{}];
// Step width for the current step: 1 << (cluster_len_shift - 1 - I()), so it halves each time
// I() grows by one.
// NOTE(review): I is presumably the index of an enclosing compile-time loop - not shown.
134 constexpr
index_t indOffset = 1 << (cluster_len_shift - 1 - I());
// Only threads whose K cluster id is below the current step width take part in this step.
136 if(thread_k_cluster_id < indOffset)
// offset1 / offset2 are computed on lines not shown. The two operands are loaded, combined
// via Accumulation::Calculate, and opData1 is stored back at offset1.
// NOTE(review): any barrier between steps is not visible in this listing - confirm it exists.
142 AccDataType opData1 = work_buffer[offset1];
143 AccDataType opData2 = work_buffer[offset2];
144 Accumulation::Calculate(opData1, opData2);
145 work_buffer(offset1) = opData1;
// Final read-back into the caller's value; `offset` is computed on a line not shown.
153 in_out_value = work_buffer[offset];
// NOTE: this is a fragment of a value+index blockwise reduction helper taken from a numbered
// source listing. The leading numbers are the listing's own line numbers; the lines missing
// between them (the opening `template <`, struct header, braces, loop header, offset
// computations) are not visible here.
// Visible template parameters: the accumulator type, the index type, the M/K thread-cluster
// lengths, the cluster arrange order, and the Accumulation policy, which defaults to
// detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>.
165 typename AccDataType,
166 typename IndexDataType,
168 typename ThreadClusterLengths_M_K,
169 typename ThreadClusterArrangeOrder,
172 typename Accumulation =
173 detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
// Compile-time check: BlockSize must equal the product of the two cluster lengths.
176 static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
177 "The product of cluster lengths should be same as BlockSize!");
// Compile-time check: the K extent of the work buffer must be greater than one.
182 static_assert(
BufferLength_K > 1,
"Parallel reduction need work on at least two elements");
// Reduce: static device function, templated on the value-buffer and index-buffer types.
//   work_val_buffer - buffer holding the values being reduced (written and read below).
//   work_idx_buffer - buffer holding the index paired with each value (written and read below).
//   in_out_value    - stored into work_val_buffer first, then overwritten with the result.
//   in_out_index    - stored into work_idx_buffer first, then overwritten with the result.
191 template <
typename BufferType,
typename IdxBufferType>
192 __device__
static void Reduce(BufferType& work_val_buffer,
193 IdxBufferType& work_idx_buffer,
194 AccDataType& in_out_value,
195 IndexDataType& in_out_index)
// Tails of two static_asserts whose conditions are on lines not shown in this listing.
198 "Buffer data type should be consistent as AccDataType!");
200 "Buffer data type should be consistent as IndexDataType!");
// Compile-time shift derived from BufferLength_K.
// NOTE(review): presumably log2(BufferLength_K) - confirm against get_shift.
202 constexpr
auto cluster_len_shift = get_shift<BufferLength_K>();
// The initializer of thread_cluster_idx is on lines not shown; its components 0 and 1 are
// used as this thread's M and K cluster ids.
204 const auto thread_cluster_idx =
207 const auto thread_m_cluster_id = thread_cluster_idx[
Number<0>{}];
208 const auto thread_k_cluster_id = thread_cluster_idx[
Number<1>{}];
// Each thread first stores its own value and index at the buffer offset that
// block_buf_desc_m_k computes for its cluster index.
210 work_val_buffer(
block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
211 work_idx_buffer(
block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index;
// Step width for the current step: 1 << I(), so it doubles each time I() grows by one.
// NOTE(review): I is presumably the index of an enclosing compile-time loop - not shown.
216 constexpr
index_t indOffset = 1 << I();
// Only threads whose K cluster id is a multiple of twice the step width take part in this step.
218 if(thread_k_cluster_id % (indOffset * 2) == 0)
// offset1 / offset2 are computed on lines not shown. Both value/index pairs are loaded,
// combined via Accumulation::Calculate, and the first pair is stored back at offset1.
// NOTE(review): any barrier between steps is not visible in this listing - confirm it exists.
224 AccDataType opData1 = work_val_buffer[offset1];
225 AccDataType opData2 = work_val_buffer[offset2];
226 IndexDataType currIndex1 = work_idx_buffer[offset1];
227 IndexDataType currIndex2 = work_idx_buffer[offset2];
229 Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2);
230 work_val_buffer(offset1) = opData1;
231 work_idx_buffer(offset1) = currIndex1;
// Final read-back of the value and its index; `offset` is computed on a line not shown.
239 in_out_value = work_val_buffer[offset];
240 in_out_index = work_idx_buffer[offset];
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:52
__host__ constexpr __device__ auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition: cluster_descriptor.hpp:13
Definition: reduction_functions_blockwise.hpp:101
static constexpr auto BufferLength_M
Definition: reduction_functions_blockwise.hpp:105
static constexpr auto thread_cluster_desc
Definition: reduction_functions_blockwise.hpp:113
static constexpr auto block_buf_desc_m_k
Definition: reduction_functions_blockwise.hpp:110
static __device__ void Reduce(BufferType &work_buffer, AccDataType &in_out_value)
Definition: reduction_functions_blockwise.hpp:116
static constexpr auto BufferLength_K
Definition: reduction_functions_blockwise.hpp:106
Definition: reduction_functions_blockwise.hpp:28
static constexpr auto BufferLength_K
Definition: reduction_functions_blockwise.hpp:33
static constexpr auto block_buf_desc_m_k
Definition: reduction_functions_blockwise.hpp:37
static constexpr auto thread_cluster_desc
Definition: reduction_functions_blockwise.hpp:40
static __device__ void Reduce(BufferType &work_buffer, AccDataType &in_out_value)
Definition: reduction_functions_blockwise.hpp:44
static constexpr auto BufferLength_M
Definition: reduction_functions_blockwise.hpp:32
Definition: reduction_functions_blockwise.hpp:175
static constexpr auto BufferLength_K
Definition: reduction_functions_blockwise.hpp:180
static constexpr auto block_buf_desc_m_k
Definition: reduction_functions_blockwise.hpp:184
static constexpr auto thread_cluster_desc
Definition: reduction_functions_blockwise.hpp:187
static __device__ void Reduce(BufferType &work_val_buffer, IdxBufferType &work_idx_buffer, AccDataType &in_out_value, IndexDataType &in_out_index)
Definition: reduction_functions_blockwise.hpp:192
static constexpr auto BufferLength_M
Definition: reduction_functions_blockwise.hpp:179
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33