Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp Source File
device_gemm_multiple_abd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <array>
7 
9 
10 namespace ck {
11 namespace tensor_operation {
12 namespace device {
13 
14 // GEMM:
15 // input : A0[M, K], A1[M, K], ... and B0[K, N], B1[K, N], ...
16 // input : D0[M, N], D1[M, N], ...
17 // output : E[M, N]
18 // C = a_op(A) * b_op(B)
19 // E = cde_op(C, D0, D1, ...)
20 // Assume:
21 // D0, D1, ... and E have the same layout
// Abstract device-level interface for the GEMM described in the comment above,
// where the A, B and D operands are each passed as a fixed-size std::array of
// tensors. Both member functions are pure virtual, so a concrete implementation
// has to override them.
// NOTE(review): listing line 33 (the declaration that should sit between the
// template parameter list and the opening brace) is absent from this extract --
// presumably the struct declaration; confirm against the upstream header.
22 template <typename AsLayout,
23  typename BsLayout,
24  typename DsLayout,
25  typename ELayout,
26  typename AsDataType,
27  typename BsDataType,
28  typename DsDataType,
29  typename EDataType,
30  typename AElementwiseOperation,
31  typename BElementwiseOperation,
32  typename CDEElementwiseOperation>
34 {
// Compile-time tensor counts, read from Size() of the matching data-type
// template parameter. They fix the length of every std::array parameter below.
35  static constexpr index_t NumATensor = AsDataType::Size();
36  static constexpr index_t NumBTensor = BsDataType::Size();
37  static constexpr index_t NumDTensor = DsDataType::Size();
38 
// Returns a std::unique_ptr<BaseArgument> for the given pointers, sizes,
// strides and element-wise operations; what the argument stores is left to
// the implementation.
//   p_as / p_bs / p_ds : one const void* per A / B / D tensor
//   p_e                : non-const void*; presumably the output E -- confirm
//   M, N, K            : the dimensions named in the GEMM comment above
//   StrideAs / StrideBs / StrideDs : one index_t stride per A / B / D tensor
//   StrideE            : single index_t stride for E
//   a_element_op / b_element_op / cde_element_op : operation objects, taken by value
39  virtual std::unique_ptr<BaseArgument>
40  MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
41  std::array<const void*, NumBTensor> p_bs,
42  std::array<const void*, NumDTensor> p_ds,
43  void* p_e,
44  ck::index_t M,
45  ck::index_t N,
46  ck::index_t K,
47  std::array<ck::index_t, NumATensor> StrideAs,
48  std::array<ck::index_t, NumBTensor> StrideBs,
49  std::array<ck::index_t, NumDTensor> StrideDs,
50  ck::index_t StrideE,
51  AElementwiseOperation a_element_op,
52  BElementwiseOperation b_element_op,
53  CDEElementwiseOperation cde_element_op) = 0;
54 
// Takes no arguments and returns a std::unique_ptr<BaseInvoker>.
// Presumably the invoker is what executes an argument built by
// MakeArgumentPointer -- confirm in device_base.hpp.
55  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
56 };
57 
58 } // namespace device
59 } // namespace tensor_operation
60 } // namespace ck
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: device_base.hpp:77
Definition: device_gemm_multiple_abd.hpp:34
static constexpr index_t NumATensor
Definition: device_gemm_multiple_abd.hpp:35
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, std::array< ck::index_t, NumATensor > StrideAs, std::array< ck::index_t, NumBTensor > StrideBs, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumBTensor
Definition: device_gemm_multiple_abd.hpp:36
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_abd.hpp:37