Composable Kernel: include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
reduction_functions_blockwise.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// Includes reconstructed from the definitions referenced in this listing
// (assumption: the exact set in the original header may differ).
#include "ck/utility/reduction_common.hpp"               // get_shift
#include "ck/utility/reduction_functions_accumulate.hpp" // detail::AccumulateWithNanCheck, detail::AccumulateWithIndexAndNanCheck
#include "ck/tensor_description/cluster_descriptor.hpp"  // make_cluster_descriptor

namespace ck {

// clang-format off
// Assume:
//  1) work_buffer is a buffer (typically LDS) allocated outside as workspace; it does not include any in/out data
//  2) work_buffer holds AccDataType elements, and its size is no less than BlockSize
//  3) in_out_value is the input data in vgpr from each thread
//  4) in_out_value is overwritten with the reduced output in vgpr for each thread
// clang-format on
template <typename AccDataType,
          index_t BlockSize,
          typename ThreadClusterLengths_M_K,
          typename ThreadClusterArrangeOrder,
          typename OpReduce,
          bool PropagateNan,
          typename Accumulation =
              detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct PartitionedBlockwiseReduction
{
    static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
                  "The product of cluster lengths should be the same as BlockSize!");

    static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
    static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);

    static_assert(BufferLength_K > 1, "Parallel reduction needs at least two elements to work on");

    // 2D view of the workspace: one row of K partial values per m cluster index
    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    template <typename BufferType>
    __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
    {
        static_assert(is_same<typename BufferType::type, AccDataType>{},
                      "Buffer data type should be consistent with AccDataType!");

        constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
        const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

        work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;

        __syncthreads();

        // Tree reduction over the K dimension: the active stride is halved every
        // step, so this schedule effectively assumes BufferLength_K is a power of two.
        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());

            if(thread_k_cluster_id < indOffset)
            {
                index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
                index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                     make_tuple(0, indOffset));

                AccDataType opData1 = work_buffer[offset1];
                AccDataType opData2 = work_buffer[offset2];
                Accumulation::Calculate(opData1, opData2);
                work_buffer(offset1) = opData1;
            }

            __syncthreads();
        });

        // every thread reads back the reduced value of its m row
        index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

        in_out_value = work_buffer[offset];
    }
};
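
// ---------------------------------------------------------------------------
// Usage sketch (editor's illustration, kept in comments; not part of the
// original header). It assumes: a 256-thread block mapped as a 1 (M) x 256 (K)
// cluster via Sequence, the reduce::Add op from reduction_operator.hpp, and
// LDS wrapped with make_dynamic_buffer from CK's dynamic-buffer utilities so
// that BufferType::type matches AccDataType as the static_assert requires.
//
// __global__ void example_block_sum(const float* p_in, float* p_out)
// {
//     using BlockwiseSum = PartitionedBlockwiseReduction<float,
//                                                        256,
//                                                        Sequence<1, 256>, // M, K lengths
//                                                        Sequence<0, 1>,   // arrange order
//                                                        reduce::Add,
//                                                        false>;           // PropagateNan
//
//     __shared__ float p_lds[256]; // workspace only; holds no in/out data
//     auto work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(p_lds, 256);
//
//     float val = p_in[blockIdx.x * 256 + threadIdx.x]; // per-thread input in vgpr
//
//     BlockwiseSum::Reduce(work_buf, val); // every thread now holds its m row's sum
//
//     if(threadIdx.x == 0)
//         p_out[blockIdx.x] = val;
// }
// ---------------------------------------------------------------------------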

// clang-format off
// Assume:
//  1) work_buffer is a buffer (typically LDS) allocated outside as workspace; it does not include any in/out data
//  2) work_buffer holds AccDataType elements, and its size is no less than BlockSize
//  3) in_out_value is the input data in vgpr from each thread
//  4) in_out_value is overwritten with the reduced output in vgpr for each thread
// clang-format on
template <typename AccDataType,
          index_t BlockSize,
          typename ThreadClusterLengths_M_K,
          typename ThreadClusterDesc,
          typename OpReduce,
          bool PropagateNan,
          typename Accumulation =
              detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct PartitionedBlockwiseReduction_v2
{
    static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
                  "The product of cluster lengths should be the same as BlockSize!");

    static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
    static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);

    static_assert(BufferLength_K > 1, "Parallel reduction needs at least two elements to work on");

    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));

    // unlike the struct above, the thread cluster descriptor is supplied by the caller
    static constexpr auto thread_cluster_desc = ThreadClusterDesc{};

    template <typename BufferType>
    __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
    {
        static_assert(is_same<typename BufferType::type, AccDataType>{},
                      "Buffer data type should be consistent with AccDataType!");

        constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
        const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

        work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;

        __syncthreads();

        // same stride-halving tree reduction as above
        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());

            if(thread_k_cluster_id < indOffset)
            {
                index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
                index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                     make_tuple(0, indOffset));

                AccDataType opData1 = work_buffer[offset1];
                AccDataType opData2 = work_buffer[offset2];
                Accumulation::Calculate(opData1, opData2);
                work_buffer(offset1) = opData1;
            }

            __syncthreads();
        });

        index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

        in_out_value = work_buffer[offset];
    }
};
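
// ---------------------------------------------------------------------------
// Editor's note (illustration, kept in comments; not part of the original
// header): this _v2 variant performs the same stride-halving reduction as
// PartitionedBlockwiseReduction, but takes the thread cluster descriptor as a
// template parameter instead of building it from lengths + arrange order, so
// a caller that already owns such a descriptor can share it across blockwise
// ops. A minimal sketch, reusing the shapes assumed in the example above:
//
// using ThreadClusterDesc_t = decltype(make_cluster_descriptor(
//     Sequence<1, 256>{}, Sequence<0, 1>{}));
//
// using BlockwiseSum_v2 = PartitionedBlockwiseReduction_v2<float,
//                                                          256,
//                                                          Sequence<1, 256>,
//                                                          ThreadClusterDesc_t,
//                                                          reduce::Add,
//                                                          false>;
// ---------------------------------------------------------------------------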

// clang-format off
// Assume:
//  1) work_val_buffer/work_idx_buffer is a buffer (typically LDS) allocated outside as workspace; it does not include any in/out data
//  2) work_val_buffer/work_idx_buffer holds AccDataType/IndexDataType elements, and its size is no less than BlockSize
//  3) in_out_value/in_out_index is the input data in vgpr from each thread
//  4) in_out_value/in_out_index is overwritten with the reduced output in vgpr for each thread
// clang-format on
template <
    typename AccDataType,
    typename IndexDataType,
    index_t BlockSize,
    typename ThreadClusterLengths_M_K,
    typename ThreadClusterArrangeOrder,
    typename OpReduce,
    bool PropagateNan,
    typename Accumulation =
        detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
struct PartitionedBlockwiseReductionWithIndex
{
    static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
                  "The product of cluster lengths should be the same as BlockSize!");

    static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
    static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);

    static_assert(BufferLength_K > 1, "Parallel reduction needs at least two elements to work on");

    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    // This interface accumulates on both data values and indices
    template <typename BufferType, typename IdxBufferType>
    __device__ static void Reduce(BufferType& work_val_buffer,
                                  IdxBufferType& work_idx_buffer,
                                  AccDataType& in_out_value,
                                  IndexDataType& in_out_index)
    {
        static_assert(is_same<typename BufferType::type, AccDataType>{},
                      "Buffer data type should be consistent with AccDataType!");
        static_assert(is_same<typename IdxBufferType::type, IndexDataType>{},
                      "Buffer data type should be consistent with IndexDataType!");

        constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
        const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

        work_val_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
        work_idx_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index;

        __syncthreads();

        // Tree reduction over the K dimension; here the stride doubles every step
        // (interleaved pairing), in contrast to the stride-halving loops above.
        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            constexpr index_t indOffset = 1 << I();

            if(thread_k_cluster_id % (indOffset * 2) == 0)
            {
                index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
                index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                     make_tuple(0, indOffset));

                AccDataType opData1      = work_val_buffer[offset1];
                AccDataType opData2      = work_val_buffer[offset2];
                IndexDataType currIndex1 = work_idx_buffer[offset1];
                IndexDataType currIndex2 = work_idx_buffer[offset2];

                Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2);
                work_val_buffer(offset1) = opData1;
                work_idx_buffer(offset1) = currIndex1;
            }

            __syncthreads();
        });

        index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

        in_out_value = work_val_buffer[offset];
        in_out_index = work_idx_buffer[offset];
    }
};
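
// ---------------------------------------------------------------------------
// Usage sketch (editor's illustration, kept in comments; not part of the
// original header): an argmax-style reduction with the index-tracking variant.
// Assumptions as in the first sketch, plus reduce::Max from
// reduction_operator.hpp and a second LDS buffer for the indices.
//
// __global__ void example_block_argmax(const float* p_in,
//                                      float* p_val_out,
//                                      index_t* p_idx_out)
// {
//     using BlockwiseArgMax =
//         PartitionedBlockwiseReductionWithIndex<float,
//                                                index_t,
//                                                256,
//                                                Sequence<1, 256>, // M, K lengths
//                                                Sequence<0, 1>,   // arrange order
//                                                reduce::Max,
//                                                false>;           // PropagateNan
//
//     __shared__ float p_lds_val[256];
//     __shared__ index_t p_lds_idx[256];
//
//     auto val_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(p_lds_val, 256);
//     auto idx_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(p_lds_idx, 256);
//
//     float val   = p_in[blockIdx.x * 256 + threadIdx.x];
//     index_t idx = static_cast<index_t>(threadIdx.x); // candidate position
//
//     BlockwiseArgMax::Reduce(val_buf, idx_buf, val, idx);
//
//     if(threadIdx.x == 0) // every thread holds the result; one thread stores it
//     {
//         p_val_out[blockIdx.x] = val;
//         p_idx_out[blockIdx.x] = idx;
//     }
// }
// ---------------------------------------------------------------------------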

} // namespace ck