Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
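// Grid-level entry point: a thin wrapper that forwards its arguments to
// GridwiseMultipleReduction::Run(). Each of the NumReduction reductions has its own output
// pointer, output descriptor, element-wise operations, and alpha/beta scaling value.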
template <typename GridwiseMultipleReduction,
          index_t NumReduction,
          typename InDataType,
          typename OutDataTypePointerTuple,
          typename AccDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M_Tuple,
          typename InElementwiseOperationTuple,
          typename AccElementwiseOperationTuple>
__global__ void
kernel_multiple_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
                                  const OutGridDesc_M_Tuple out_grid_desc_m_tuple,
                                  const InElementwiseOperationTuple in_elementwise_op_tuple,
                                  const AccElementwiseOperationTuple acc_elementwise_op_tuple,
                                  Array<AccDataType, NumReduction> alpha_values,
                                  const InDataType* const __restrict__ p_in_value_global,
                                  Array<AccDataType, NumReduction> beta_values,
                                  OutDataTypePointerTuple p_out_value_global_tuple)
{
    GridwiseMultipleReduction::Run(in_grid_desc_m_k,
                                   out_grid_desc_m_tuple,
                                   in_elementwise_op_tuple,
                                   acc_elementwise_op_tuple,
                                   alpha_values,
                                   p_in_value_global,
                                   beta_values,
                                   p_out_value_global_tuple);
};

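// Gridwise operator that computes NumReduction independent reductions of an M x K input along
// the K dimension in one pass. Each thread reduces its own MThreadSliceSize x KThreadSliceSize
// tile privately (threadwise, no inter-thread cooperation), sliding its source window along K
// until the whole reduce length is consumed.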
template <index_t NumReduction,
          typename InDataType,
          typename OutDataTypePointerTuple,
          typename AccDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M_Tuple,
          typename ReduceOperation,
          typename InElementwiseOperationTuple,
          typename AccElementwiseOperationTuple,
          InMemoryDataOperationEnum OutMemoryDataOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          typename OutDstVectorSizeSeq>
struct GridwiseMultipleReduction_mk_to_m_threadwise
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(NumReduction == OutDataTypePointerTuple::Size() &&
                      NumReduction == OutGridDesc_M_Tuple::Size() &&
                      NumReduction == OutDstVectorSizeSeq::Size() &&
                      NumReduction == InElementwiseOperationTuple::Size() &&
                      NumReduction == AccElementwiseOperationTuple::Size(),
                  "All tuples should have the same size as the number of reductions!");

    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                 ThreadReduceSrcDesc_M_K,
                                                 ThreadReduceDstDesc_M,
                                                 ReduceOperation,
                                                 PropagateNan>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};

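    // Per-thread flow: load a tile of the input, apply the per-reduction input element-wise
    // operation, reduce it along K into per-reduction accumulators, then apply the accumulator
    // element-wise operation, scale by alpha, optionally blend in beta * prior output values,
    // and store MThreadSliceSize results per reduction back to global memory.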
    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M_Tuple& out_grid_desc_m_tuple,
                               const InElementwiseOperationTuple& in_elementwise_op_tuple,
                               const AccElementwiseOperationTuple& acc_elementwise_op_tuple,
                               Array<AccDataType, NumReduction> alpha_values,
                               const InDataType* const __restrict__ p_in_value_global,
                               Array<AccDataType, NumReduction> beta_values,
                               OutDataTypePointerTuple p_out_value_global_tuple)
    {
        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            ReduceOperation::template GetIdentityValue<InDataType>());
        auto out_global_val_buf_tuple = generate_tuple(
            [&](auto iR) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize());
            },
            Number<NumReduction>{});

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

        auto in_thread_buf_tuple = generate_tuple(
            [&](auto iR) {
                (void)iR;
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * KThreadSliceSize,
                                    true>{};
            },
            Number<NumReduction>{});

        auto accu_value_buf_tuple = generate_tuple(
            [&](auto iR) {
                (void)iR;
                return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>{};
            },
            Number<NumReduction>{});

        static_for<0, NumReduction, 1>{}([&](auto iR) {
            static_for<0, MThreadSliceSize, 1>{}(
                [&](auto J) { accu_value_buf_tuple(iR)(J) = identityVal; });
        });

        const index_t thread_global_1d_id = get_thread_global_1d_id();

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
                                                                    AccDataType,
                                                                    InGridDesc_M_K,
                                                                    decltype(thread_buffer_desc),
                                                                    ThreadBufferLengths,
                                                                    ThreadBufferDimAccessOrder,
                                                                    InSrcVectorDim,
                                                                    InSrcVectorSize,
                                                                    1,
                                                                    false>(
            in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

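        // Main reduction loop: each iteration loads one MThreadSliceSize x KThreadSliceSize
        // tile, feeds it through every reduction's input element-wise operation and the
        // threadwise reducer, then advances the source window by KThreadSliceSize along K.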
        index_t reducedLength = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_val_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            static_for<0, NumReduction, 1>{}([&](auto iR) {
                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // do element-wise pre-reduction operation
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
                                                    in_thread_buf(Number<offset>{}));
                    });
                });

                ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR));
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedLength += KThreadSliceSize;
        } while(reducedLength < toReduceLength);

        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

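        // Epilogue, executed once per reduction: apply the accumulator element-wise operation,
        // scale by alpha, blend with the prior destination values when beta != 0, and write the
        // results out with the per-reduction destination vector size.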
        static_for<0, NumReduction, 1>{}([&](auto iR) {
            using OutDataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
            using OutDataType        = remove_cvref_t<remove_pointer_t<OutDataTypePointer>>;

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I),
                                             accu_value_buf_tuple(iR)(I));

                accu_value_buf_tuple(iR)(I) *= alpha_values[iR];
            });

            if(!float_equal_zero{}(beta_values[iR]))
            {
                StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValueBuf;

                auto threadwise_dst_load =
                    ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                     OutDataType,
                                                     decltype(out_grid_desc_m_tuple[iR]),
                                                     decltype(reduced_data_desc),
                                                     Sequence<MThreadSliceSize>,
                                                     Sequence<0>,
                                                     0,
                                                     OutDstVectorSizeSeq::At(iR),
                                                     1,
                                                     false>(
                        out_grid_desc_m_tuple[iR],
                        make_multi_index(thread_global_1d_id * MThreadSliceSize));

                threadwise_dst_load.Run(out_grid_desc_m_tuple[iR],
                                        out_global_val_buf_tuple(iR),
                                        reduced_data_desc,
                                        make_tuple(I0),
                                        priorDstValueBuf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                    accu_value_buf_tuple(iR)(I) +=
                        type_convert<AccDataType>(priorDstValueBuf[I]) * beta_values[iR];
                });
            };

            auto threadwise_dst_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   decltype(out_grid_desc_m_tuple[iR]),
                                                   PassThroughOp,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSizeSeq::At(iR),
                                                   OutMemoryDataOperation,
                                                   1,
                                                   true>(
                    out_grid_desc_m_tuple[iR],
                    make_multi_index(thread_global_1d_id * MThreadSliceSize),
                    PassThroughOp{});

            threadwise_dst_store.Run(reduced_data_desc,
                                     make_tuple(I0),
                                     accu_value_buf_tuple[iR],
                                     out_grid_desc_m_tuple[iR],
                                     out_global_val_buf_tuple(iR));
        });
    };
};

} // namespace ck
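For reference, the computation that this gridwise operator performs for each reduction r and
output row m is sketched below in plain host-side C++. This is a minimal illustration only: no
Composable Kernel types are used, all names are hypothetical, a sum stands in for the generic
ReduceOperation, and the device code additionally vectorizes loads/stores and tiles the work
across threads.

#include <cstddef>
#include <functional>
#include <vector>

// Illustrative reference for what the gridwise op computes per reduction r over an
// M x K row-major input:
//   out_r[m] = acc_op_r(reduce_k(in_op_r(in[m][k]))) * alpha_r + beta_r * prior_out_r[m]
void multiple_reduce_reference(const std::vector<float>& in, // M * K values, row-major
                               std::size_t M,
                               std::size_t K,
                               const std::vector<std::function<float(float)>>& in_ops,
                               const std::vector<std::function<float(float)>>& acc_ops,
                               const std::vector<float>& alphas,
                               const std::vector<float>& betas,
                               std::vector<std::vector<float>>& outs) // NumReduction x M
{
    const std::size_t num_reduction = outs.size();
    for(std::size_t r = 0; r < num_reduction; ++r)
    {
        for(std::size_t m = 0; m < M; ++m)
        {
            float acc = 0.0f; // identity value; 0 assumes a sum-like reduction
            for(std::size_t k = 0; k < K; ++k)
                acc += in_ops[r](in[m * K + k]); // input element-wise op, then accumulate over K
            acc = acc_ops[r](acc) * alphas[r];   // post-reduction element-wise op and alpha scale
            // blend with the prior output only when beta is non-zero, as the device code does
            outs[r][m] = (betas[r] == 0.0f) ? acc : acc + betas[r] * outs[r][m];
        }
    }
}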