/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp Source File
device_gemm_bias_e_permute.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <array>
7 
8 #include "device_base.hpp"
9 
10 namespace ck {
11 namespace tensor_operation {
12 namespace device {
13 
15 {
18 };
19 
// Abstract operation interface (both member functions below are pure virtual):
// a GEMM whose product is combined with one extra tensor D by an elementwise
// functor. D and E are addressed through DEGridDesc_M0_M1_M2_N0_N1 descriptors
// rather than a plain row stride.
20 // input : A[M, K], B[K, N],
21 // input : D[M, N] -- a single D tensor (MakeArgumentPointer takes one p_d)
22 // output : E[M, N]
23 // C = a_op(A) * b_op(B)
24 // E = cde_op(C, D)
// Template parameters: the types of the elementwise functors a_op, b_op and
// cde_op in the formulas above (each is passed by value to MakeArgumentPointer).
25 template <typename AElementwiseOperation,
26  typename BElementwiseOperation,
27  typename CDEElementwiseOperation>
// NOTE(review): listing line 28 -- the struct's declaration line (its name and
// any base class) -- is missing from this extract; the Doxygen index only
// records that the body opens at line 29.
29 {
// Creates a type-erased argument object (std::unique_ptr<BaseArgument>) for one
// problem instance. Pure virtual: concrete subclasses provide it.
//   p_a, p_b        : inputs A[M, K] and B[K, N] (memory space not visible here)
//   p_d             : the single D tensor consumed by cde_op
//   p_e             : the output E (the only non-const pointer)
//   M, N, K         : GEMM problem sizes, per the A/B/E shapes above
//   StrideA, StrideB: strides of A and B (presumably leading dimension -- confirm)
//   d_gride_desc,
//   e_gride_desc    : 5-D layouts of D and E; there is no StrideD/StrideE, so
//                     these descriptors are the only D/E layout information
//                     ("gride" is the upstream spelling of these parameter names)
//   a_element_op,
//   b_element_op,
//   cde_element_op  : the a_op / b_op / cde_op functors, taken by value
30  virtual std::unique_ptr<BaseArgument>
31  MakeArgumentPointer(const void* p_a,
32  const void* p_b,
33  const void* p_d,
34  void* p_e,
35  ck::index_t M,
36  ck::index_t N,
37  ck::index_t K,
38  ck::index_t StrideA,
39  ck::index_t StrideB,
40  DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc,
41  DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc,
42  AElementwiseOperation a_element_op,
43  BElementwiseOperation b_element_op,
44  CDEElementwiseOperation cde_element_op) = 0;
45 
// Creates a type-erased invoker object (std::unique_ptr<BaseInvoker>). Pure
// virtual. Presumably used to run an argument built by MakeArgumentPointer --
// confirm against BaseInvoker in device_base.hpp.
46  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
47 };
48 
49 } // namespace device
50 } // namespace tensor_operation
51 } // namespace ck
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: device_base.hpp:77
Definition: device_gemm_bias_e_permute.hpp:15
ck::index_t stride_M1_
Definition: device_gemm_bias_e_permute.hpp:17
ck::index_t stride_M0_
Definition: device_gemm_bias_e_permute.hpp:17
ck::index_t stride_M2_
Definition: device_gemm_bias_e_permute.hpp:17
ck::index_t N1_
Definition: device_gemm_bias_e_permute.hpp:16
ck::index_t stride_N0_
Definition: device_gemm_bias_e_permute.hpp:17
ck::index_t M1_
Definition: device_gemm_bias_e_permute.hpp:16
ck::index_t M0_
Definition: device_gemm_bias_e_permute.hpp:16
ck::index_t M2_
Definition: device_gemm_bias_e_permute.hpp:16
ck::index_t N0_
Definition: device_gemm_bias_e_permute.hpp:16
ck::index_t stride_N1_
Definition: device_gemm_bias_e_permute.hpp:17
Definition: device_gemm_bias_e_permute.hpp:29
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_d, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc, DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0