/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_cgemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_cgemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_cgemm.hpp Source File
device_cgemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 #include "device_base.hpp"
6 
7 namespace ck {
8 namespace tensor_operation {
9 namespace device {
10 
11 template <typename AElementwiseOperation,
12  typename BElementwiseOperation,
13  typename CElementwiseOperation>
14 struct DeviceCGemm : public BaseOperator
15 {
16  virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
17  const void* p_a_imag,
18  const void* p_b_real,
19  const void* p_b_imag,
20  void* p_c_real,
21  void* p_c_imag,
22  void* p_workspace,
23  ck::index_t M,
24  ck::index_t N,
25  ck::index_t K,
26  ck::index_t StrideA,
27  ck::index_t StrideB,
28  ck::index_t StrideC,
29  AElementwiseOperation a_element_op,
30  BElementwiseOperation b_element_op,
31  CElementwiseOperation c_element_op,
32  ck::index_t KBatch = 1) = 0;
33 
34  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
35  virtual std::size_t GetWorkspaceSize(index_t MRaw,
36  index_t NRaw,
37  index_t KRaw,
38  index_t StrideA,
39  index_t StrideB,
40  index_t StrideC) const = 0;
41 };
42 
43 template <typename AElementwiseOperation,
44  typename BElementwiseOperation,
45  typename CElementwiseOperation>
46 using DeviceCGemmPtr = std::unique_ptr<
48 
49 } // namespace device
50 } // namespace tensor_operation
51 } // namespace ck
std::unique_ptr< DeviceCGemm< AElementwiseOperation, BElementwiseOperation, CElementwiseOperation > > DeviceCGemmPtr
Definition: device_cgemm.hpp:47
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: device_base.hpp:77
Definition: device_cgemm.hpp:15
virtual std::size_t GetWorkspaceSize(index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC) const =0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a_real, const void *p_a_imag, const void *p_b_real, const void *p_b_imag, void *p_c_real, void *p_c_imag, void *p_workspace, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, ck::index_t KBatch=1)=0