template <typename GridwiseReduction,
          bool OutputIndex,
          bool TransformIndexKtoGlobal,
          bool HaveIndexInput,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename AccElementwiseOperation>
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
                                         const OutGridDesc_M out_grid_desc_m,
                                         const InElementwiseOperation in_elementwise_op,
                                         const AccElementwiseOperation acc_elementwise_op,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_in_value_global,
                                         const IndexDataType* const __restrict__ p_in_index_global,
                                         AccDataType beta,
                                         OutDataType* const __restrict__ p_out_value_global,
                                         IndexDataType* const __restrict__ p_out_index_global)
{
    // Compile-time dispatch: plain value reduction, or reduction that also outputs indices.
    if constexpr(!OutputIndex)
    {
        GridwiseReduction::Run(in_grid_desc_m_k, out_grid_desc_m, in_elementwise_op,
                               acc_elementwise_op, alpha, p_in_value_global, beta,
                               p_out_value_global);
    }
    else
    {
        GridwiseReduction::template RunWithIndex<TransformIndexKtoGlobal, HaveIndexInput>(
            in_grid_desc_m_k, out_grid_desc_m, in_elementwise_op, acc_elementwise_op, alpha,
            p_in_value_global, p_in_index_global, beta, p_out_value_global, p_out_index_global);
    }
}
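The kernel selects between the two gridwise entry points entirely at compile time, so the discarded branch is never instantiated. Below is a minimal standalone sketch of the same if constexpr dispatch pattern; the Reducer policy type and the dispatch helper are hypothetical illustrations, not part of this header.

#include <cstdio>

// Hypothetical stand-in for a gridwise reduction policy; not part of this header.
struct Reducer
{
    static void Run() { std::printf("value-only reduction\n"); }

    template <bool TransformIndexKtoGlobal, bool HaveIndexInput>
    static void RunWithIndex()
    {
        std::printf("indexed reduction: transform=%d, have_index=%d\n",
                    int(TransformIndexKtoGlobal), int(HaveIndexInput));
    }
};

// Same dispatch shape as kernel_reduce_threadwise: the discarded branch is never
// instantiated, so only the selected entry point must be valid for the policy.
template <typename GridwiseReduction,
          bool OutputIndex,
          bool TransformIndexKtoGlobal,
          bool HaveIndexInput>
void dispatch()
{
    if constexpr(!OutputIndex)
        GridwiseReduction::Run();
    else
        GridwiseReduction::template RunWithIndex<TransformIndexKtoGlobal, HaveIndexInput>();
}

int main()
{
    dispatch<Reducer, false, false, false>(); // value-only path
    dispatch<Reducer, true, true, false>();   // indexed (argmax-style) path
}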
template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          InMemoryDataOperationEnum OutMemoryDataOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_threadwise
{
    // The per-thread slice must be evenly divisible by the chosen vector widths.
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (MThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                           const OutGridDesc_M& out_grid_desc_m,
                           const InElementwiseOperation& in_elementwise_op,
                           const AccElementwiseOperation& acc_elementwise_op,
                           AccDataType alpha,
                           const InDataType* const __restrict__ p_in_value_global,
                           AccDataType beta,
                           OutDataType* const __restrict__ p_out_value_global)
{
    const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

    // Global-memory views; out-of-range reads of the input return the identity value.
    const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_in_value_global,
        in_grid_desc_m_k.GetElementSpaceSize(),
        ReduceOperation::template GetIdentityValue<InDataType>());
    auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_out_value_global, out_grid_desc_m.GetElementSpaceSize());

    // ...thread_buffer_desc and register buffers (in_thread_buf, accu_value_buf),
    // plus the ThreadwiseReduce alias, elided in this listing...

    const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

    const index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

    // Each thread streams its own MThreadSliceSize x KThreadSliceSize window of the input.
    auto threadwise_src_val_load =
        ThreadwiseTensorSliceTransfer_v2<InDataType,
                                         AccDataType,
                                         InGridDesc_M_K,
                                         decltype(thread_buffer_desc),
                                         /* ...access-order and vector parameters elided... */>(
            in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

    constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

    index_t reducedLength = 0;
    do
    {
        threadwise_src_val_load.Run(in_grid_desc_m_k, in_global_val_buf, thread_buffer_desc,
                                    make_tuple(I0, I0), in_thread_buf);

        // Apply the element-wise pre-reduction operation to every loaded element.
        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
            static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                in_elementwise_op(in_thread_buf(Number<offset>{}),
                                  in_thread_buf(Number<offset>{}));
            });
        });

        ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
        threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
        reducedLength += KThreadSliceSize;
    } while(reducedLength < toReduceLength);

    // Post-reduction element-wise operation, then alpha scaling.
    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
        acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
        accu_value_buf(I) *= alpha;
    });

    // When beta is non-zero, blend the prior destination values into the result.
    if(!float_equal_zero{}(beta))
    {
        auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<
            OutDataType, OutDataType, OutGridDesc_M, decltype(reduced_data_desc),
            /* ...access-order and vector parameters elided... */>(
            out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

        threadwise_dst_load.Run(out_grid_desc_m, dst_global_buf, reduced_data_desc,
                                make_tuple(I0), priorDstValue_buf);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
        });
    }

    // Store the per-thread results using the configured OutMemoryDataOperation.
    auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<
        AccDataType, OutDataType, decltype(reduced_data_desc), OutGridDesc_M, PassThroughOp,
        /* ... */ OutMemoryDataOperation /* ... */>(
        out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize), PassThroughOp{});

    threadwise_dst_store.Run(
        reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
}
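Stripped of tensor descriptors and vectorized transfers, Run computes, per output row m, an alpha/beta-scaled reduction: out[m] = alpha * acc_op(reduce_k(in_op(in[m][k]))) + beta * out_prev[m], with the beta term applied only when beta is non-zero. A scalar reference sketch of the same arithmetic, assuming max as the reduce operation; reduce_rows and its parameters are hypothetical, for illustration only.

#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

// Scalar reference for the arithmetic performed by Run, using max as the example
// reduce operation. reduce_rows is hypothetical and not part of this header.
void reduce_rows(const std::vector<std::vector<float>>& in, // M rows of K values
                 std::vector<float>& out,                   // M prior / result values
                 float alpha,
                 float beta)
{
    const float identity = -std::numeric_limits<float>::infinity(); // GetIdentityValue analogue
    for(std::size_t m = 0; m < in.size(); ++m)
    {
        float acc = identity;
        for(float v : in[m])
            acc = std::max(acc, v); // in_elementwise_op would be applied to v first
        // acc_elementwise_op would be applied here, before the alpha scaling
        acc *= alpha;
        if(beta != 0.0f) // mirrors the float_equal_zero{}(beta) guard
            acc += out[m] * beta;
        out[m] = acc;
    }
}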
template <bool TransformIndexKtoGlobal, bool HaveIndexInput>
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                                    const OutGridDesc_M& out_grid_desc_m,
                                    const InElementwiseOperation& in_elementwise_op,
                                    const AccElementwiseOperation& acc_elementwise_op,
                                    AccDataType alpha,
                                    const InDataType* const __restrict__ p_in_value_global,
                                    const IndexDataType* const __restrict__ p_in_index_global,
                                    AccDataType beta,
                                    OutDataType* const __restrict__ p_out_value_global,
                                    IndexDataType* const __restrict__ p_out_index_global)
{
    (void)acc_elementwise_op;

    const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

    // Global-memory views over the input values/indices and the output values/indices.
    const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_in_value_global,
        in_grid_desc_m_k.GetElementSpaceSize(),
        ReduceOperation::template GetIdentityValue<InDataType>());
    const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

    auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
    auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_out_index_global, out_grid_desc_m.GetElementSpaceSize());

    // ...register buffers holding MThreadSliceSize * KThreadSliceSize candidate values
    // and indices (in_thread_val_buf, in_thread_idx_buf), the per-row accumulators
    // (accu_value_buf, accu_index_buf), and the ThreadwiseReduceWithIndex alias,
    // elided in this listing...

    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
        accu_value_buf(I) = identityVal;
        accu_index_buf(I) = 0;
    });

    const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

    const index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

    auto threadwise_src_val_load =
        ThreadwiseTensorSliceTransfer_v2<InDataType,
                                         AccDataType,
                                         InGridDesc_M_K,
                                         decltype(thread_buffer_desc),
                                         /* ...access-order and vector parameters elided... */>(
            in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

    constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

    index_t indexStart    = 0;
    index_t reducedLength = 0;

    if constexpr(HaveIndexInput)
    {
        // Candidate indices arrive with the values (e.g. a second reduction stage),
        // so both are streamed from global memory.
        auto threadwise_src_idx_load =
            ThreadwiseTensorSliceTransfer_v2<IndexDataType,
                                             IndexDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             /* ...access-order and vector parameters elided... */>(
                in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        do
        {
            threadwise_src_val_load.Run(in_grid_desc_m_k, in_global_val_buf, thread_buffer_desc,
                                        make_tuple(I0, I0), in_thread_val_buf);
            threadwise_src_idx_load.Run(in_grid_desc_m_k, in_global_idx_buf, thread_buffer_desc,
                                        make_tuple(I0, I0), in_thread_idx_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                    in_elementwise_op(in_thread_val_buf(Number<offset>{}),
                                      in_thread_val_buf(Number<offset>{}));
                });
            });

            ThreadwiseReduceWithIndex::Reduce(
                in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);

            threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
            threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            indexStart += KThreadSliceSize;
            reducedLength += KThreadSliceSize;
        } while(reducedLength < toReduceLength);
    }
    else
    {
        // No index input: indices are generated locally as positions along K.
        do
        {
            threadwise_src_val_load.Run(in_grid_desc_m_k, in_global_val_buf, thread_buffer_desc,
                                        make_tuple(I0, I0), in_thread_val_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                    in_thread_idx_buf(Number<offset>{}) = indexStart + iK;
                    in_elementwise_op(in_thread_val_buf(Number<offset>{}),
                                      in_thread_val_buf(Number<offset>{}));
                });
            });

            ThreadwiseReduceWithIndex::Reduce(
                in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);

            threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            indexStart += KThreadSliceSize;
            reducedLength += KThreadSliceSize;
        } while(reducedLength < toReduceLength);

        // Optionally convert each winning K position into an element offset in the
        // original (untransformed) input tensor.
        if constexpr(TransformIndexKtoGlobal)
        {
            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                const auto coord = make_tensor_coordinate(
                    in_grid_desc_m_k,
                    make_multi_index(thread_global_1d_id * MThreadSliceSize + I,
                                     accu_index_buf(I)));
                accu_index_buf(I) = coord.GetOffset();
            });
        }
    }

    // Post-reduction element-wise operation, then alpha scaling.
    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
        acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
        accu_value_buf(I) *= alpha;
    });

    // When beta is non-zero, blend the prior destination values into the result.
    if(!float_equal_zero{}(beta))
    {
        auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<
            OutDataType, OutDataType, OutGridDesc_M, decltype(reduced_data_desc),
            /* ...access-order and vector parameters elided... */>(
            out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

        threadwise_dst_load.Run(out_grid_desc_m, out_global_val_buf, reduced_data_desc,
                                make_tuple(I0), priorDstValue_buf);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
        });
    }

    // Store the reduced values and their indices using OutMemoryDataOperation.
    auto threadwise_dst_val_store = ThreadwiseTensorSliceTransfer_v1r3<
        AccDataType, OutDataType, decltype(reduced_data_desc), OutGridDesc_M, PassThroughOp,
        /* ... */ OutMemoryDataOperation /* ... */>(
        out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize), PassThroughOp{});

    auto threadwise_dst_idx_store = ThreadwiseTensorSliceTransfer_v1r3<
        IndexDataType, IndexDataType, decltype(reduced_data_desc), OutGridDesc_M, PassThroughOp,
        /* ... */ OutMemoryDataOperation /* ... */>(
        out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize), PassThroughOp{});

    threadwise_dst_val_store.Run(
        reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf);
    threadwise_dst_idx_store.Run(
        reduced_data_desc, make_tuple(I0), accu_index_buf, out_grid_desc_m, out_global_idx_buf);
}
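RunWithIndex follows the same structure as Run but tracks, for every row, the K position at which the winning value occurred: with HaveIndexInput the candidate indices are read alongside the values (as in a second reduction stage), otherwise they are generated locally as indexStart + iK, and TransformIndexKtoGlobal can then convert the winning K position into an offset in the original input tensor. A scalar argmax sketch of the locally-generated-index path; argmax_row and k_slice are hypothetical illustrations, not part of this header.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Scalar reference for the no-index-input path of RunWithIndex, reducing one row
// with argmax semantics. argmax_row is hypothetical and not part of this header.
void argmax_row(const std::vector<float>& row,
                std::size_t k_slice, // plays the role of KThreadSliceSize
                float& best_val,
                std::int32_t& best_idx)
{
    best_val = -std::numeric_limits<float>::infinity(); // reduction identity
    best_idx = 0;
    std::size_t index_start = 0; // mirrors indexStart in the listing
    while(index_start < row.size())
    {
        const std::size_t end = std::min(index_start + k_slice, row.size());
        for(std::size_t k = index_start; k < end; ++k)
        {
            // ThreadwiseReduceWithIndex keeps the value and its index together
            if(row[k] > best_val)
            {
                best_val = row[k];
                best_idx = static_cast<std::int32_t>(k);
            }
        }
        index_start += k_slice; // mirrors indexStart += KThreadSliceSize
    }
}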