template <typename GridwiseReduction,
          bool OutputIndex,
          bool HaveIndexInput,
          bool TransformIndexKtoGlobal,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename AccElementwiseOperation>
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
                                         const OutGridDesc_M out_grid_desc_m,
                                         const InElementwiseOperation in_elementwise_op,
                                         const AccElementwiseOperation acc_elementwise_op,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_in_value_global,
                                         const IndexDataType* const __restrict__ p_in_index_global,
                                         AccDataType beta,
                                         OutDataType* const __restrict__ p_out_value_global,
                                         IndexDataType* const __restrict__ p_out_index_global)
{
    if constexpr(!OutputIndex)
    {
        GridwiseReduction::Run(in_grid_desc_m_k,
                               out_grid_desc_m,
                               in_elementwise_op,
                               acc_elementwise_op,
                               alpha,
                               p_in_value_global,
                               beta,
                               p_out_value_global);
    }
    else
    {
        GridwiseReduction::template RunWithIndex<TransformIndexKtoGlobal, HaveIndexInput>(
            in_grid_desc_m_k,
            out_grid_desc_m,
            in_elementwise_op,
            acc_elementwise_op,
            alpha,
            p_in_value_global,
            p_in_index_global,
            beta,
            p_out_value_global,
            p_out_index_global);
    }
}
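// Illustrative only: a minimal launch sketch for the value-only path, assuming a fully
// specialized gridwise struct (see below) and already-built grid descriptors. `Gridwise`,
// `grid_size`, `p_in` and `p_out` are placeholders for this sketch, not names from this
// header, and the raw <<<...>>> launch stands in for whatever launch wrapper the host
// code actually uses.
//
//   kernel_reduce_threadwise<Gridwise,
//                            false, // OutputIndex: write reduced values only
//                            false, // HaveIndexInput: no indices from a prior stage
//                            false, // TransformIndexKtoGlobal: unused on this path
//                            InDataType, OutDataType, AccDataType, IndexDataType,
//                            InGridDesc_M_K, OutGridDesc_M,
//                            InElementwiseOperation, AccElementwiseOperation>
//       <<<grid_size, BlockSize>>>(in_grid_desc_m_k, out_grid_desc_m,
//                                  InElementwiseOperation{}, AccElementwiseOperation{},
//                                  alpha, p_in, nullptr, beta, p_out, nullptr);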
 
template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          InMemoryDataOperationEnum OutMemoryDataOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_threadwise
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (MThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
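    // Illustrative only: one parameter choice that satisfies the assertion above, assuming
    // vectorized loads along the K dimension (example values, not defaults of this header):
    //   InSrcVectorDim = 1, KThreadSliceSize = 8, InSrcVectorSize = 4  -> 8 % 4 == 0
    //   MThreadSliceSize = 4, OutDstVectorSize = 2                     -> 4 % 2 == 0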
 
    using ThreadBufferDimAccessOrder =
        typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};

    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M& out_grid_desc_m,
                               const InElementwiseOperation& in_elementwise_op,
                               const AccElementwiseOperation& acc_elementwise_op,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
    {
        using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                     ThreadReduceSrcDesc_M_K,
                                                     ThreadReduceDstDesc_M,
                                                     ReduceOperation,
                                                     PropagateNan>;

        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        // Out-of-bounds reads from the input buffer return the reduction identity value,
        // so padded elements do not perturb the result.
        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            ReduceOperation::template GetIdentityValue<InDataType>());
        auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
 
        // Thread-local (VGPR) buffers: one M x K input slice, one per-M accumulator.
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        const index_t thread_global_1d_id =
            get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             AccDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             Sequence<MThreadSliceSize, KThreadSliceSize>,
                                             ThreadBufferDimAccessOrder,
                                             InSrcVectorDim,
                                             InSrcVectorSize
                                             /* ... */>(
                in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

        index_t reducedLength = 0;
        do
        {
            threadwise_src_val_load.Run(in_grid_desc_m_k,
                                        in_global_val_buf,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_buf);

            // Apply the element-wise pre-op to every loaded element.
            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                    in_elementwise_op(in_thread_buf(Number<offset>{}),
                                      in_thread_buf(Number<offset>{}));
                });
            });

            ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);

            threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedLength += KThreadSliceSize;
        } while(reducedLength < toReduceLength);

        // Element-wise post-op, then scale by alpha.
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

            accu_value_buf(I) *= alpha;
        });

        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        if(!float_equal_zero{}(beta))
        {
            // Read the prior output back and blend it in, scaled by beta.
            auto threadwise_dst_load =
                ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                 AccDataType,
                                                 OutGridDesc_M,
                                                 decltype(reduced_data_desc)
                                                 /* ... */>(
                    out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

            StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                priorDstValue_buf;

            threadwise_dst_load.Run(out_grid_desc_m,
                                    dst_global_buf,
                                    reduced_data_desc,
                                    make_tuple(I0),
                                    priorDstValue_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
            });
        }

        auto threadwise_dst_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               OutDataType,
                                               decltype(reduced_data_desc),
                                               OutGridDesc_M,
                                               PassThroughOp,
                                               /* ...slice/vector arguments elided... */
                                               OutMemoryDataOperation
                                               /* ... */>(
                out_grid_desc_m,
                make_multi_index(thread_global_1d_id * MThreadSliceSize),
                PassThroughOp{});

        threadwise_dst_store.Run(
            reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
    }
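    // In scalar terms, Run above computes, per output element (a sketch of the intent,
    // not code from this header):
    //
    //   out[m] = alpha * acc_op(reduce_k(in_op(in[m][k]))) + beta * out_prev[m]
    //
    // where the read-back of out_prev is skipped entirely when beta == 0.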
 
    template <bool TransformIndexKtoGlobal, bool HaveIndexInput>
    __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                                        const OutGridDesc_M& out_grid_desc_m,
                                        const InElementwiseOperation& in_elementwise_op,
                                        const AccElementwiseOperation& acc_elementwise_op,
                                        AccDataType alpha,
                                        const InDataType* const __restrict__ p_in_value_global,
                                        const IndexDataType* const __restrict__ p_in_index_global,
                                        AccDataType beta,
                                        OutDataType* const __restrict__ p_out_value_global,
                                        IndexDataType* const __restrict__ p_out_index_global)
    {
        using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
                                                                       IndexDataType,
                                                                       ThreadReduceSrcDesc_M_K,
                                                                       ThreadReduceDstDesc_M,
                                                                       ReduceOperation,
                                                                       PropagateNan>;

        (void)acc_elementwise_op;

        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            ReduceOperation::template GetIdentityValue<InDataType>());
        const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_index_global, out_grid_desc_m.GetElementSpaceSize());

        // Thread-local buffers for the value/index slices and the per-M accumulators.
        StaticBuffer<AddressSpaceEnum::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_val_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr,
                     IndexDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_idx_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = identityVal;
            accu_index_buf(I) = 0;
        });

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
 
        const index_t thread_global_1d_id =
            get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             AccDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             Sequence<MThreadSliceSize, KThreadSliceSize>,
                                             ThreadBufferDimAccessOrder,
                                             InSrcVectorDim,
                                             InSrcVectorSize
                                             /* ... */>(
                in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

        index_t indexStart    = 0;
        index_t reducedLength = 0;

        if constexpr(HaveIndexInput)
        {
            // Indices come from a previous reduction stage: load them alongside the values.
            auto threadwise_src_idx_load =
                ThreadwiseTensorSliceTransfer_v2<IndexDataType,
                                                 IndexDataType,
                                                 InGridDesc_M_K,
                                                 decltype(thread_buffer_desc),
                                                 Sequence<MThreadSliceSize, KThreadSliceSize>,
                                                 ThreadBufferDimAccessOrder,
                                                 InSrcVectorDim,
                                                 InSrcVectorSize
                                                 /* ... */>(
                    in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

            do
            {
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                threadwise_src_idx_load.Run(in_grid_desc_m_k,
                                            in_global_idx_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_idx_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        in_elementwise_op(in_thread_val_buf(Number<offset>{}),
                                          in_thread_val_buf(Number<offset>{}));
                    });
                });

                ThreadwiseReduceWithIndex::Reduce(
                    in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
                threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                indexStart += KThreadSliceSize;
                reducedLength += KThreadSliceSize;
            } while(reducedLength < toReduceLength);
        }
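        // When HaveIndexInput is set (branch above), the winning index comes from the
        // index buffer produced by a previous reduction stage; in the else-branch below
        // it is synthesized from the current K window as indexStart + iK.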
 
        else
        {
            do
            {
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        // Synthesize the K-local index of this element on the fly.
                        in_thread_idx_buf(Number<offset>{}) = indexStart + iK;

                        in_elementwise_op(in_thread_val_buf(Number<offset>{}),
                                          in_thread_val_buf(Number<offset>{}));
                    });
                });

                ThreadwiseReduceWithIndex::Reduce(
                    in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                indexStart += KThreadSliceSize;
                reducedLength += KThreadSliceSize;
            } while(reducedLength < toReduceLength);

            // Optionally convert each winning K index into a global element offset.
            if constexpr(TransformIndexKtoGlobal)
            {
                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                    const auto coord = make_tensor_coordinate(
                        in_grid_desc_m_k,
                        make_multi_index(thread_global_1d_id * MThreadSliceSize + I,
                                         accu_index_buf(I)));

                    accu_index_buf(I) = coord.GetOffset();
                });
            }
        }
 
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

            accu_value_buf(I) *= alpha;
        });

        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        if(!float_equal_zero{}(beta))
        {
            // Read the prior output values back and blend them in, scaled by beta.
            auto threadwise_dst_load =
                ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                 AccDataType,
                                                 OutGridDesc_M,
                                                 decltype(reduced_data_desc)
                                                 /* ... */>(
                    out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

            StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                priorDstValue_buf;

            threadwise_dst_load.Run(out_grid_desc_m,
                                    out_global_val_buf,
                                    reduced_data_desc,
                                    make_tuple(I0),
                                    priorDstValue_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
            });
        }

        auto threadwise_dst_val_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               OutDataType,
                                               decltype(reduced_data_desc),
                                               OutGridDesc_M,
                                               PassThroughOp,
                                               /* ...slice/vector arguments elided... */
                                               OutMemoryDataOperation
                                               /* ... */>(
                out_grid_desc_m,
                make_multi_index(thread_global_1d_id * MThreadSliceSize),
                PassThroughOp{});

        auto threadwise_dst_idx_store =
            ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                               IndexDataType,
                                               decltype(reduced_data_desc),
                                               OutGridDesc_M,
                                               PassThroughOp,
                                               /* ...slice/vector arguments elided... */
                                               OutMemoryDataOperation
                                               /* ... */>(
                out_grid_desc_m,
                make_multi_index(thread_global_1d_id * MThreadSliceSize),
                PassThroughOp{});

        threadwise_dst_val_store.Run(
            reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf);
        threadwise_dst_idx_store.Run(
            reduced_data_desc, make_tuple(I0), accu_index_buf, out_grid_desc_m, out_global_idx_buf);
    }
};
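// In scalar terms, RunWithIndex computes, per output element (a sketch of the intent,
// not code from this header; argreduce_k denotes the index-tracking counterpart of the
// reduction, e.g. argmax for a Max reduction):
//
//   out_val[m] = alpha * acc_op(reduce_k(in_op(in[m][k]))) + beta * out_prev[m]
//   out_idx[m] = argreduce_k(in_op(in[m][k]))  // K-local, or a global element offset
//                                              // when TransformIndexKtoGlobal is set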
 