/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp Source File
device_gemm_multiple_abd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <array>
7 
9 
10 namespace ck {
11 namespace tensor_operation {
12 namespace device {
13 
14 // GEMM with multiple A, B and D input tensors:
15 // input : A0[M, K], B0[K, N],
16 // input : D0[M, N], D1[M, N], ...
17 // output : E[M, N]
18 // C = a_op(A) * b_op(B)
19 // E = cde_op(C, D0, D1, ...)
20 // Assume:
21 // D0, D1, ... and E have the same layout
22 template <typename AsLayout,
23  typename BsLayout,
24  typename DsLayout,
25  typename ELayout,
26  typename AsDataType,
27  typename BsDataType,
28  typename DsDataType,
29  typename EDataType,
30  typename AElementwiseOperation,
31  typename BElementwiseOperation,
32  typename CDEElementwiseOperation>
34 {
35  static constexpr index_t NumATensor = AsDataType::Size(); // number of A input tensors (size of p_as / StrideAs)
36  static constexpr index_t NumBTensor = BsDataType::Size(); // number of B input tensors (size of p_bs / StrideBs)
37  static constexpr index_t NumDTensor = DsDataType::Size(); // number of D input tensors (size of p_ds / StrideDs)
38 
39  virtual std::unique_ptr<BaseArgument> // pure virtual (= 0): builds a type-erased argument object
40  MakeArgumentPointer(std::array<const void*, NumATensor> p_as, // one pointer per A tensor
41  std::array<const void*, NumBTensor> p_bs, // one pointer per B tensor
42  std::array<const void*, NumDTensor> p_ds, // one pointer per D tensor
43  void* p_e, // output tensor E
44  ck::index_t M,
45  ck::index_t N,
46  ck::index_t K,
47  std::array<ck::index_t, NumATensor> StrideAs, // one stride per A tensor
48  std::array<ck::index_t, NumBTensor> StrideBs, // one stride per B tensor
49  std::array<ck::index_t, NumDTensor> StrideDs, // one stride per D tensor
50  ck::index_t StrideE,
51  AElementwiseOperation a_element_op,
52  BElementwiseOperation b_element_op,
53  CDEElementwiseOperation cde_element_op) = 0;
54 
55  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; // pure virtual: returns a type-erased invoker
56 };
57 
58 // GEMM with multiple A, B and D input tensors and an extra KBatch argument:
59 // input : A0[M, K], B0[K, N],
60 // input : D0[M, N], D1[M, N], ...
61 // output : E[M, N]
62 // C = a_op(A) * b_op(B)
63 // E = cde_op(C, D0, D1, ...)
64 // Assume:
65 // D0, D1, ... and E have the same layout
66 template <typename AsLayout,
67  typename BsLayout,
68  typename DsLayout,
69  typename ELayout,
70  typename AsDataType,
71  typename BsDataType,
72  typename DsDataType,
73  typename EDataType,
74  typename AElementwiseOperation,
75  typename BElementwiseOperation,
76  typename CDEElementwiseOperation>
78 {
79  static constexpr index_t NumATensor = AsDataType::Size(); // number of A input tensors (size of p_as / StrideAs)
80  static constexpr index_t NumBTensor = BsDataType::Size(); // number of B input tensors (size of p_bs / StrideBs)
81  static constexpr index_t NumDTensor = DsDataType::Size(); // number of D input tensors (size of p_ds / StrideDs)
82 
83  virtual std::unique_ptr<BaseArgument> // pure virtual (= 0): builds a type-erased argument object
84  MakeArgumentPointer(std::array<const void*, NumATensor> p_as, // one pointer per A tensor
85  std::array<const void*, NumBTensor> p_bs, // one pointer per B tensor
86  std::array<const void*, NumDTensor> p_ds, // one pointer per D tensor
87  void* p_e, // output tensor E
88  ck::index_t M,
89  ck::index_t N,
90  ck::index_t K,
91  std::array<ck::index_t, NumATensor> StrideAs, // one stride per A tensor
92  std::array<ck::index_t, NumBTensor> StrideBs, // one stride per B tensor
93  std::array<ck::index_t, NumDTensor> StrideDs, // one stride per D tensor
94  ck::index_t StrideE,
95  ck::index_t KBatch, // presumably the number of K-dimension splits; the wrapper in this file passes 1 — confirm
96  AElementwiseOperation a_element_op,
97  BElementwiseOperation b_element_op,
98  CDEElementwiseOperation cde_element_op) = 0;
99 
100  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; // pure virtual: returns a type-erased invoker
101 };
102 
110 template <typename AsLayout,
111  typename BsLayout,
112  typename DsLayout,
113  typename ELayout,
114  typename AsDataType,
115  typename BsDataType,
116  typename DsDataType,
117  typename EDataType,
118  typename AElementwiseOperation,
119  typename BElementwiseOperation,
120  typename CDEElementwiseOperation>
122  BsLayout,
123  DsLayout,
124  ELayout,
125  AsDataType,
126  BsDataType,
127  DsDataType,
128  EDataType,
129  AElementwiseOperation,
130  BElementwiseOperation,
131  CDEElementwiseOperation>
132 { // forwards every call to the owned p_op_; MakeArgumentPointer additionally fixes KBatch = 1
133 
135  BsLayout,
136  DsLayout,
137  ELayout,
138  AsDataType,
139  BsDataType,
140  DsDataType,
141  EDataType,
142  AElementwiseOperation,
143  BElementwiseOperation,
144  CDEElementwiseOperation>; // DeviceOp: wrapped op type whose MakeArgumentPointer takes KBatch
145 
146  static constexpr index_t NumATensor = AsDataType::Size(); // number of A input tensors
147  static constexpr index_t NumBTensor = BsDataType::Size(); // number of B input tensors
148  static constexpr index_t NumDTensor = DsDataType::Size(); // number of D input tensors
149 
150 #ifndef __HIPCC_RTC__ // everything up to the matching #endif is omitted when __HIPCC_RTC__ is defined
151 
152  explicit DeviceGemmMultipleABDSplitKWrapper(std::unique_ptr<DeviceOp> p_op) // takes ownership of the wrapped op
153  : p_op_(std::move(p_op))
154  {
155  }
156 
157  bool IsSupportedArgument(const BaseArgument* p_arg) override
158  {
159  return p_op_->IsSupportedArgument(p_arg); // delegated unchanged
160  }
161  std::unique_ptr<BaseArgument> // interface without KBatch; forwards to the wrapped op
162  MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
163  std::array<const void*, NumBTensor> p_bs,
164  std::array<const void*, NumDTensor> p_ds,
165  void* p_e,
166  ck::index_t M,
167  ck::index_t N,
168  ck::index_t K,
169  std::array<ck::index_t, NumATensor> StrideAs,
170  std::array<ck::index_t, NumBTensor> StrideBs,
171  std::array<ck::index_t, NumDTensor> StrideDs,
172  ck::index_t StrideE,
173  AElementwiseOperation a_element_op,
174  BElementwiseOperation b_element_op,
175  CDEElementwiseOperation cde_element_op) override
176  {
177  return p_op_->MakeArgumentPointer(p_as,
178  p_bs,
179  p_ds,
180  p_e,
181  M,
182  N,
183  K,
184  StrideAs,
185  StrideBs,
186  StrideDs,
187  StrideE,
188  1, // KBatch: fixed to 1 (presumably no K splitting); the only argument not taken from the caller
189  a_element_op,
190  b_element_op,
191  cde_element_op);
192  }
193 
194  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
195  {
196  return p_op_->MakeInvokerPointer(); // delegated unchanged
197  }
198 
199  std::string GetTypeString() const override { return p_op_->GetTypeString(); } // delegated unchanged
200 
201  private:
202  std::unique_ptr<DeviceOp> p_op_; // owned wrapped op; NOTE(review): never null-checked, so a null p_op makes every forwarded call dereference null
203 
204 #endif // __HIPCC_RTC__
205 };
206 
207 } // namespace device
208 } // namespace tensor_operation
209 } // namespace ck
Definition: ck.hpp:268
int32_t index_t
Definition: ck.hpp:299
Definition: device_base.hpp:197
Definition: device_base.hpp:223
Definition: device_gemm_multiple_abd.hpp:34
static constexpr index_t NumATensor
Definition: device_gemm_multiple_abd.hpp:35
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, std::array< ck::index_t, NumATensor > StrideAs, std::array< ck::index_t, NumBTensor > StrideBs, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumBTensor
Definition: device_gemm_multiple_abd.hpp:36
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_abd.hpp:37
Definition: device_gemm_multiple_abd.hpp:78
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_abd.hpp:81
static constexpr index_t NumATensor
Definition: device_gemm_multiple_abd.hpp:79
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, std::array< ck::index_t, NumATensor > StrideAs, std::array< ck::index_t, NumBTensor > StrideBs, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumBTensor
Definition: device_gemm_multiple_abd.hpp:80
Wrapper for backward compatibility that allows instances of DeviceGemmMultipleABDSplitK to be used in con...
Definition: device_gemm_multiple_abd.hpp:132
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_multiple_abd.hpp:194
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_multiple_abd.hpp:157
static constexpr index_t NumBTensor
Definition: device_gemm_multiple_abd.hpp:147
static constexpr index_t NumATensor
Definition: device_gemm_multiple_abd.hpp:146
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_abd.hpp:148
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, std::array< ck::index_t, NumATensor > StrideAs, std::array< ck::index_t, NumBTensor > StrideBs, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition: device_gemm_multiple_abd.hpp:162
DeviceGemmMultipleABDSplitKWrapper(std::unique_ptr< DeviceOp > p_op)
Definition: device_gemm_multiple_abd.hpp:152
std::string GetTypeString() const override
Definition: device_gemm_multiple_abd.hpp:199