include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/utility/reduction_functions_blockwise.hpp"
#include "ck/utility/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {

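// Thin __global__ wrapper for the multiblock multiple-reduction kernel: it forwards all of its
// arguments unchanged to GridwiseMultipleReduction::Run(), which implements the per-workgroup
// reduction logic. alpha_values and beta_values hold one scale pair per reduction.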
template <typename GridwiseMultipleReduction,
          index_t NumReduction,
          typename InDataType,
          typename OutDataTypePointerTuple,
          typename AccDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M_Tuple,
          typename InElementwiseOperationTuple,
          typename AccElementwiseOperationTuple>
__global__ void
kernel_multiple_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
                                  const OutGridDesc_M_Tuple out_grid_desc_m_tuple,
                                  const InElementwiseOperationTuple in_elementwise_op_tuple,
                                  const AccElementwiseOperationTuple acc_elementwise_op_tuple,
                                  index_t block_group_size,
                                  index_t num_k_block_tile_iteration,
                                  Array<AccDataType, NumReduction> alpha_values,
                                  const InDataType* const __restrict__ p_in_value_global,
                                  Array<AccDataType, NumReduction> beta_values,
                                  OutDataTypePointerTuple p_out_value_global_tuple)
{
    GridwiseMultipleReduction::Run(in_grid_desc_m_k,
                                   out_grid_desc_m_tuple,
                                   in_elementwise_op_tuple,
                                   acc_elementwise_op_tuple,
                                   block_group_size,
                                   num_k_block_tile_iteration,
                                   alpha_values,
                                   p_in_value_global,
                                   beta_values,
                                   p_out_value_global_tuple);
};

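// Gridwise struct performing NumReduction reductions of a 2D (M, K) input along K in a single
// pass, producing one length-M output per reduction. "Multiblock" means the K dimension of each
// M tile is split across block_group_size workgroups: workgroup b handles M tile
// b / block_group_size and K partition b % block_group_size, where one partition covers
// num_k_block_tile_iteration * K_BlockTileSize elements of K. When block_group_size > 1 the
// partial results of the different partitions are typically combined by choosing an accumulating
// OutMemoryDataOperation (e.g. AtomicAdd) for the final store.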
template <index_t NumReduction,
          typename InDataType,
          typename OutDataTypePointerTuple,
          typename AccDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M_Tuple,
          typename ReduceOperation,
          typename InElementwiseOperationTuple,
          typename AccElementwiseOperationTuple,
          InMemoryDataOperationEnum OutMemoryDataOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          typename OutDstVectorSizeSeq>
struct GridwiseMultipleReduction_mk_to_m_multiblock
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(NumReduction == OutDataTypePointerTuple::Size() &&
                      NumReduction == OutGridDesc_M_Tuple::Size() &&
                      NumReduction == OutDstVectorSizeSeq::Size() &&
                      NumReduction == InElementwiseOperationTuple::Size() &&
                      NumReduction == AccElementwiseOperationTuple::Size(),
                  "All tuples should have the same size as the number of Reductions!");

    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
                                                          BlockSize,
                                                          ThreadClusterLengths_M_K,
                                                          ThreadClusterArrangeOrder,
                                                          ReduceOperation,
                                                          PropagateNan>;

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                 ThreadReduceSrcDesc_M_K,
                                                 ThreadReduceDstDesc_M,
                                                 ReduceOperation,
                                                 PropagateNan>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

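    // One workgroup reduces its K partition of one M tile. For each reduction i the value written
    // back is, in effect:
    //   out_i[m] = alpha_values[i] * acc_op_i(reduce_i(in_op_i(in[m, :]))) + beta_values[i] * prior_out_i[m]
    // where the beta term is only read and added when beta_values[i] is non-zero.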
    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M_Tuple& out_grid_desc_m_tuple,
                               const InElementwiseOperationTuple& in_elementwise_op_tuple,
                               const AccElementwiseOperationTuple& acc_elementwise_op_tuple,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               Array<AccDataType, NumReduction> alpha_values,
                               const InDataType* const __restrict__ p_in_value_global,
                               Array<AccDataType, NumReduction> beta_values,
                               OutDataTypePointerTuple p_out_value_global_tuple)
    {
        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        // LDS, reused by all reductions
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            ReduceOperation::template GetIdentityValue<InDataType>());
        auto out_global_val_buf_tuple = generate_tuple(
            [&](auto iR) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize());
            },
            Number<NumReduction>{});

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

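        // Per-thread register staging: in_thread_buf receives the raw loaded (M, K) slice,
        // in_thread_buf_tuple holds one copy per reduction after its input elementwise op, and
        // accu_value_buf_tuple keeps MThreadSliceSize running accumulators per reduction,
        // initialised below to the reduction's identity value.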
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

        auto in_thread_buf_tuple = generate_tuple(
            [&](auto iR) {
                (void)iR;
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * KThreadSliceSize,
                                    true>{};
            },
            Number<NumReduction>{});

        auto accu_value_buf_tuple = generate_tuple(
            [&](auto iR) {
                (void)iR;
                return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>{};
            },
            Number<NumReduction>{});

        static_for<0, NumReduction, 1>{}([&](auto iR) {
            static_for<0, MThreadSliceSize, 1>{}(
                [&](auto J) { accu_value_buf_tuple(iR)(J) = identityVal; });
        });

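        // Decompose the flat workgroup index: blkgroup_id selects the M tile this workgroup works
        // on, block_local_id selects which K partition of that tile it reduces.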
        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / block_group_size;
        const index_t block_local_id  = block_global_id % block_group_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
                                                                    AccDataType,
                                                                    InGridDesc_M_K,
                                                                    decltype(thread_buffer_desc),
                                                                    ThreadBufferLengths,
                                                                    ThreadBufferDimAccessOrder,
                                                                    InSrcVectorDim,
                                                                    InSrcVectorSize,
                                                                    1,
                                                                    false>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

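        // Main reduction loop: each iteration loads one (MThreadSliceSize x KThreadSliceSize)
        // slice per thread, applies every reduction's input elementwise operation into that
        // reduction's thread buffer, folds the buffer along K into the per-reduction
        // accumulators, then advances the source window by one K block tile.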
        index_t reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_val_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            static_for<0, NumReduction, 1>{}([&](auto iR) {
                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // do element-wise pre-reduction operation
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
                                                    in_thread_buf(Number<offset>{}));
                    });
                });

                ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR));
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

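        // Final stage, once per reduction: combine the per-thread accumulators across the K
        // thread cluster through the LDS workspace, then let threads with thread_k_cluster_id == 0
        // apply the accumulator elementwise op and the alpha scale, optionally add the existing
        // output scaled by beta, and store the length-MThreadSliceSize result to global memory
        // using OutMemoryDataOperation.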
        static_for<0, NumReduction, 1>{}([&](auto iR) {
            using OutDataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
            using OutDataType        = remove_cvref_t<remove_pointer_t<OutDataTypePointer>>;

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf_tuple(iR)(I));
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                if(thread_k_cluster_id == 0)
                {
                    acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I),
                                                 accu_value_buf_tuple(iR)(I));

                    accu_value_buf_tuple(iR)(I) *= alpha_values[iR];
                }
            });

            if(thread_k_cluster_id == 0)
            {
                if(!float_equal_zero{}(beta_values[iR]))
                {
                    StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                        priorDstValueBuf;

                    auto threadwise_dst_load =
                        ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                         OutDataType,
                                                         decltype(out_grid_desc_m_tuple[iR]),
                                                         decltype(reduced_data_desc),
                                                         Sequence<MThreadSliceSize>,
                                                         Sequence<0>,
                                                         0,
                                                         OutDstVectorSizeSeq::At(iR),
                                                         1,
                                                         false>(
                            out_grid_desc_m_tuple[iR],
                            make_multi_index(blkgroup_id * M_BlockTileSize +
                                             thread_m_cluster_id * MThreadSliceSize));

                    threadwise_dst_load.Run(out_grid_desc_m_tuple[iR],
                                            out_global_val_buf_tuple(iR),
                                            reduced_data_desc,
                                            make_tuple(I0),
                                            priorDstValueBuf);

                    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                        accu_value_buf_tuple(iR)(I) +=
                            type_convert<AccDataType>(priorDstValueBuf[I]) * beta_values[iR];
                    });
                };

                auto threadwise_dst_store =
                    ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                       OutDataType,
                                                       decltype(reduced_data_desc),
                                                       decltype(out_grid_desc_m_tuple[iR]),
                                                       PassThroughOp,
                                                       Sequence<MThreadSliceSize>,
                                                       Sequence<0>,
                                                       0,
                                                       OutDstVectorSizeSeq::At(iR),
                                                       OutMemoryDataOperation,
                                                       1,
                                                       true>(
                        out_grid_desc_m_tuple[iR],
                        make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize),
                        PassThroughOp{});

                threadwise_dst_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_value_buf_tuple[iR],
                                         out_grid_desc_m_tuple[iR],
                                         out_global_val_buf_tuple(iR));
            };
        });
    };
};

} // namespace ck