/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp Source File
device_reduce_multi_d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <array>
7 #include <memory>
8 
9 #include "ck/ck.hpp"
11 
12 namespace ck {
13 namespace tensor_operation {
14 namespace device {
15 
16 template <typename InDataType,
17  typename DsDataType,
18  typename AccDataType,
19  typename OutDataType,
20  index_t Rank,
21  index_t NumReduceDim,
22  typename ReduceOperation,
23  typename InElementwiseOperation,
24  typename OutElementwiseOperation>
26 {
27  static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
28 
29  static constexpr index_t NumDTensor = DsDataType::Size();
30 
31  virtual std::unique_ptr<BaseArgument>
32  MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
33  const std::array<index_t, Rank> inStrides,
34  const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths,
35  const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides,
36  const std::array<index_t, NumOutDim> outLengths,
37  const std::array<index_t, NumOutDim> outStrides,
38  const std::array<int, NumReduceDim> reduceDims,
39  const void* in_dev,
40  const std::array<const void*, NumDTensor> ds_dev,
41  void* out_dev,
42  const InElementwiseOperation in_elementwise_op,
43  const OutElementwiseOperation out_elementwise_op) = 0;
44 
45  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
46 };
47 
48 template <typename InDataType,
49  typename DsDataType,
50  typename AccDataType,
51  typename OutDataType,
52  index_t Rank,
53  index_t NumReduceDim,
54  typename ReduceOperation,
55  typename InElementwiseOperation,
56  typename OutElementwiseOperation>
57 using DeviceReduceMultiDPtr = std::unique_ptr<DeviceReduceMultiD<InDataType,
58  DsDataType,
59  AccDataType,
60  OutDataType,
61  Rank,
62  NumReduceDim,
63  ReduceOperation,
64  InElementwiseOperation,
65  OutElementwiseOperation>>;
66 
67 } // namespace device
68 } // namespace tensor_operation
69 } // namespace ck
std::unique_ptr< DeviceReduceMultiD< InDataType, DsDataType, AccDataType, OutDataType, Rank, NumReduceDim, ReduceOperation, InElementwiseOperation, OutElementwiseOperation > > DeviceReduceMultiDPtr
Definition: device_reduce_multi_d.hpp:65
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: device_base.hpp:77
ck::tensor_operation::device::DeviceReduceMultiD
Definition: device_reduce_multi_d.hpp:26
static constexpr index_t NumOutDim
Definition: device_reduce_multi_d.hpp:27
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< std::array< index_t, NumOutDim >, NumDTensor > DsLengths, const std::array< std::array< index_t, NumOutDim >, NumDTensor > DsStrides, const std::array< index_t, NumOutDim > outLengths, const std::array< index_t, NumOutDim > outStrides, const std::array< int, NumReduceDim > reduceDims, const void *in_dev, const std::array< const void *, NumDTensor > ds_dev, void *out_dev, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op)=0
static constexpr index_t NumDTensor
Definition: device_reduce_multi_d.hpp:29
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0