/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp Source File
device_gemm_multiple_d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 #ifndef __HIPCC_RTC__
6 #include <array>
7 #endif
8 
9 #include "ck/utility/array.hpp"
11 
12 namespace ck {
13 namespace tensor_operation {
14 namespace device {
15 
16 // GEMM with auxiliary D tensors (abstract interface: every member function below is pure virtual):
17 // input : A[M, K], B[K, N],
18 // input : D0[M, N], D1[M, N], ...
19 // output : E[M, N]
20 // C = a_op(A) * b_op(B)
21 // E = cde_op(C, D0, D1, ...)
22 // Assume:
23 // D0, D1, ... and E have the same layout
24 template <typename ALayout,
25  typename BLayout,
26  typename DsLayout,
27  typename ELayout,
28  typename ADataType,
29  typename BDataType,
30  typename DsDataType,
31  typename EDataType,
32  typename AElementwiseOperation,
33  typename BElementwiseOperation,
34  typename CDEElementwiseOperation>
 // NOTE(review): listing line 35 (presumably the struct declaration) is absent from this extract.
36 {
 // Number of D tensors, taken from DsDataType::Size(); it is the length of p_ds and StrideDs.
37  static constexpr index_t NumDTensor = DsDataType::Size();
38 
39 #ifndef __HIPCC_RTC__
 // Everything up to the matching #endif is compiled only when __HIPCC_RTC__ is not defined.
 //
 // Builds the argument object for one GEMM problem.
 //   p_a, p_b        : inputs A and B
 //   p_ds            : one pointer per D tensor (NumDTensor entries)
 //   p_e             : output E (the only non-const pointer)
 //   M, N, K         : problem sizes as in A[M, K], B[K, N], E[M, N]
 //   StrideA/B/Ds/E  : strides of the corresponding tensors (unit not visible here, confirm)
 //   *_element_op    : the a_op / b_op / cde_op functors of the formula in the header comment
 // Returns an owning pointer to a BaseArgument.
40  virtual std::unique_ptr<BaseArgument>
41  MakeArgumentPointer(const void* p_a,
42  const void* p_b,
43  std::array<const void*, NumDTensor> p_ds,
44  void* p_e,
45  ck::index_t M,
46  ck::index_t N,
47  ck::index_t K,
48  ck::index_t StrideA,
49  ck::index_t StrideB,
50  std::array<ck::index_t, NumDTensor> StrideDs,
51  ck::index_t StrideE,
52  AElementwiseOperation a_element_op,
53  BElementwiseOperation b_element_op,
54  CDEElementwiseOperation cde_element_op) = 0;
55 
 // Returns an owning pointer to a BaseInvoker; implementations supply the concrete type.
56  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
57 #endif
58 };
59 
60 // GEMM with auxiliary D tensors and an extra KBatch argument (abstract interface):
61 // input : A[M, K], B[K, N],
62 // input : D0[M, N], D1[M, N], ...
63 // output : E[M, N]
64 // C = a_op(A) * b_op(B)
65 // E = cde_op(C, D0, D1, ...)
66 // Assume:
67 // D0, D1, ... and E have the same layout
68 template <typename ALayout,
69  typename BLayout,
70  typename DsLayout,
71  typename ELayout,
72  typename ADataType,
73  typename BDataType,
74  typename DsDataType,
75  typename EDataType,
76  typename AElementwiseOperation,
77  typename BElementwiseOperation,
78  typename CDEElementwiseOperation>
 // NOTE(review): listing line 79 (presumably the struct declaration) is absent from this extract.
80 {
 // Number of D tensors, taken from DsDataType::Size(); it is the length of p_ds and StrideDs.
81  static constexpr index_t NumDTensor = DsDataType::Size();
82 
83 #ifndef __HIPCC_RTC__
 // Everything up to the matching #endif is compiled only when __HIPCC_RTC__ is not defined.
 //
 // Builds the argument object for one GEMM problem.
 //   p_a, p_b        : inputs A and B
 //   p_ds            : one pointer per D tensor (NumDTensor entries)
 //   p_e             : output E (the only non-const pointer)
 //   M, N, K         : problem sizes as in A[M, K], B[K, N], E[M, N]
 //   StrideA/B/Ds/E  : strides of the corresponding tensors (unit not visible here, confirm)
 //   KBatch          : presumably the split-K batch count; semantics are not shown here, confirm
 //   *_element_op    : the a_op / b_op / cde_op functors of the formula in the header comment
 // Returns an owning pointer to a BaseArgument.
84  virtual std::unique_ptr<BaseArgument>
85  MakeArgumentPointer(const void* p_a,
86  const void* p_b,
87  std::array<const void*, NumDTensor> p_ds,
88  void* p_e,
89  ck::index_t M,
90  ck::index_t N,
91  ck::index_t K,
92  ck::index_t StrideA,
93  ck::index_t StrideB,
94  std::array<ck::index_t, NumDTensor> StrideDs,
95  ck::index_t StrideE,
96  ck::index_t KBatch,
97  AElementwiseOperation a_element_op,
98  BElementwiseOperation b_element_op,
99  CDEElementwiseOperation cde_element_op) = 0;
100 
 // Returns an owning pointer to a BaseInvoker; implementations supply the concrete type.
101  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
102 #endif
103 };
104 
105 // GEMM with auxiliary D tensors, a KBatch argument and a pre-shuffle query (abstract interface):
106 // input : A[M, K], B[K, N],
107 // input : D0[M, N], D1[M, N], ...
108 // output : E[M, N]
109 // C = a_op(A) * b_op(B)
110 // E = cde_op(C, D0, D1, ...)
111 // Assume:
112 // D0, D1, ... and E have the same layout
113 template <typename ALayout,
114  typename BLayout,
115  typename DsLayout,
116  typename ELayout,
117  typename ADataType,
118  typename BDataType,
119  typename DsDataType,
120  typename EDataType,
121  typename AElementwiseOperation,
122  typename BElementwiseOperation,
123  typename CDEElementwiseOperation>
 // NOTE(review): listing line 124 (presumably the struct declaration) is absent from this extract.
125 {
 // Number of D tensors, taken from DsDataType::Size(); it is the length of p_ds and StrideDs.
126  static constexpr index_t NumDTensor = DsDataType::Size();
127 
128 #ifndef CK_CODE_GEN_RTC
 // Everything up to the matching #endif is compiled only when CK_CODE_GEN_RTC is not defined.
 // NOTE(review): other interfaces in this file guard the same API with __HIPCC_RTC__ instead;
 // confirm whether the two macros are meant to differ.
 //
 // Builds the argument object for one GEMM problem.
 //   p_a, p_b        : inputs A and B
 //   p_ds            : one pointer per D tensor (NumDTensor entries)
 //   p_e             : output E (the only non-const pointer)
 //   M, N, K         : problem sizes as in A[M, K], B[K, N], E[M, N]
 //   StrideA/B/Ds/E  : strides of the corresponding tensors (unit not visible here, confirm)
 //   KBatch          : presumably the split-K batch count; semantics are not shown here, confirm
 //   *_element_op    : the a_op / b_op / cde_op functors of the formula in the header comment
 // Returns an owning pointer to a BaseArgument.
129  virtual std::unique_ptr<BaseArgument>
130  MakeArgumentPointer(const void* p_a,
131  const void* p_b,
132  std::array<const void*, NumDTensor> p_ds,
133  void* p_e,
134  ck::index_t M,
135  ck::index_t N,
136  ck::index_t K,
137  ck::index_t StrideA,
138  ck::index_t StrideB,
139  std::array<ck::index_t, NumDTensor> StrideDs,
140  ck::index_t StrideE,
141  ck::index_t KBatch,
142  AElementwiseOperation a_element_op,
143  BElementwiseOperation b_element_op,
144  CDEElementwiseOperation cde_element_op) = 0;
145 
 // Returns an owning pointer to a BaseInvoker; implementations supply the concrete type.
146  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
147 
 // Returns an int supplied by the implementation. The meaning of the value is not visible
 // in this file; presumably it describes the pre-shuffled operand layout, confirm.
148  virtual int GetPreShuffleParameters() = 0;
149 #endif
150 };
151 
// Abstract GEMM interface with auxiliary D tensors plus separate scale inputs for A and B.
// Compared with the other interfaces in this file it adds the AScaleDataType / BScaleDataType
// types, the ScaleBlockSize constant, and the p_a_scale / p_b_scale pointers with their strides.
// NOTE(review): how the scales enter the computation is not visible here; presumably one scale
// covers ScaleBlockSize elements, confirm against an implementation.
152 template <typename ALayout,
153  typename BLayout,
154  typename DsLayout,
155  typename ELayout,
156  typename ADataType,
157  typename AScaleDataType,
158  typename BDataType,
159  typename BScaleDataType,
160  typename DsDataType,
161  typename EDataType,
162  index_t ScaleBlockSize,
163  typename AElementwiseOperation,
164  typename BElementwiseOperation,
165  typename CDEElementwiseOperation>
 // NOTE(review): listing line 166 (presumably the struct declaration) is absent from this extract.
167 {
 // Number of D tensors, taken from DsDataType::Size(); it is the length of p_ds and StrideDs.
168  static constexpr index_t NumDTensor = DsDataType::Size();
169 
170 #ifndef CK_CODE_GEN_RTC
 // Everything up to the matching #endif is compiled only when CK_CODE_GEN_RTC is not defined.
 //
 // Builds the argument object for one GEMM problem.
 //   p_a, p_b               : inputs A and B
 //   p_a_scale, p_b_scale   : scale inputs for A and B
 //   p_ds                   : one pointer per D tensor (NumDTensor entries)
 //   p_e                    : output E (the only non-const pointer)
 //   M, N, K                : problem sizes (presumably A[M, K], B[K, N], E[M, N], confirm)
 //   StrideA/B/Ds/E         : strides of the corresponding tensors (unit not visible here)
 //   StrideAScale/BScale    : strides of the two scale inputs
 //   KBatch                 : presumably the split-K batch count, confirm
 //   *_element_op           : element-wise functors for A, B and the C/D/E stage
 // Returns an owning pointer to a BaseArgument.
171  virtual std::unique_ptr<BaseArgument>
172  MakeArgumentPointer(const void* p_a,
173  const void* p_a_scale,
174  const void* p_b,
175  const void* p_b_scale,
176  std::array<const void*, NumDTensor> p_ds,
177  void* p_e,
178  ck::index_t M,
179  ck::index_t N,
180  ck::index_t K,
181  ck::index_t StrideA,
182  ck::index_t StrideAScale,
183  ck::index_t StrideB,
184  ck::index_t StrideBScale,
185  std::array<ck::index_t, NumDTensor> StrideDs,
186  ck::index_t StrideE,
187  ck::index_t KBatch,
188  AElementwiseOperation a_element_op,
189  BElementwiseOperation b_element_op,
190  CDEElementwiseOperation cde_element_op) = 0;
191 
 // Returns an owning pointer to a BaseInvoker; implementations supply the concrete type.
192  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
193 
 // Returns an int supplied by the implementation. The meaning of the value is not visible
 // in this file; presumably it describes the pre-shuffled operand layout, confirm.
194  virtual int GetPreShuffleParameters() = 0;
195 #endif
196 };
197 
// Adapter that owns a DeviceOp (p_op_) and exposes the MakeArgumentPointer overload that has
// no KBatch parameter. Every call is forwarded to the wrapped operation; MakeArgumentPointer
// inserts a fixed KBatch of 1.
// NOTE(review): listing lines 198-204 (presumably the original doc comment) are absent from
// this extract.
205 template <typename ALayout,
206  typename BLayout,
207  typename DsLayout,
208  typename ELayout,
209  typename ADataType,
210  typename BDataType,
211  typename DsDataType,
212  typename EDataType,
213  typename AElementwiseOperation,
214  typename BElementwiseOperation,
215  typename CDEElementwiseOperation>
 // NOTE(review): listing line 216 (presumably the struct declaration and the start of its
 // base-class template argument list) is absent; the lines below continue that list.
217  BLayout,
218  DsLayout,
219  ELayout,
220  ADataType,
221  BDataType,
222  DsDataType,
223  EDataType,
224  AElementwiseOperation,
225  BElementwiseOperation,
226  CDEElementwiseOperation>
227 {
 // NOTE(review): listing line 228 (presumably the start of the DeviceOp alias used by p_op_)
 // is absent; the lines below continue its template argument list.
229  BLayout,
230  DsLayout,
231  ELayout,
232  ADataType,
233  BDataType,
234  DsDataType,
235  EDataType,
236  AElementwiseOperation,
237  BElementwiseOperation,
238  CDEElementwiseOperation>;
239 
 // Number of D tensors, taken from DsDataType::Size(); it is the length of p_ds and StrideDs.
240  static constexpr index_t NumDTensor = DsDataType::Size();
241 
242 #ifndef __HIPCC_RTC__
243 
 // Takes ownership of the wrapped operation.
 // NOTE(review): p_op is not checked for null, and every method below dereferences p_op_.
244  explicit DeviceGemmMultipleDSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
245  : p_op_(std::move(p_op))
246  {
247  }
248 
 // Forwards the support check to the wrapped operation unchanged.
249  bool IsSupportedArgument(const BaseArgument* p_arg) override
250  {
251  return p_op_->IsSupportedArgument(p_arg);
252  }
 // Builds the argument through the wrapped operation. All parameters are passed through in
 // order, with the constant 1 inserted as KBatch between StrideE and a_element_op.
253  std::unique_ptr<BaseArgument>
254  MakeArgumentPointer(const void* p_a,
255  const void* p_b,
256  std::array<const void*, NumDTensor> p_ds,
257  void* p_e,
258  ck::index_t M,
259  ck::index_t N,
260  ck::index_t K,
261  ck::index_t StrideA,
262  ck::index_t StrideB,
263  std::array<ck::index_t, NumDTensor> StrideDs,
264  ck::index_t StrideE,
265  AElementwiseOperation a_element_op,
266  BElementwiseOperation b_element_op,
267  CDEElementwiseOperation cde_element_op) override
268  {
269  return p_op_->MakeArgumentPointer(p_a,
270  p_b,
271  p_ds,
272  p_e,
273  M,
274  N,
275  K,
276  StrideA,
277  StrideB,
278  StrideDs,
279  StrideE,
280  1, // KBatch
281  a_element_op,
282  b_element_op,
283  cde_element_op);
284  }
285 
 // Returns the invoker created by the wrapped operation.
286  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
287  {
288  return p_op_->MakeInvokerPointer();
289  }
290 
 // Reports the wrapped operation's type string, not a wrapper-specific one.
291  std::string GetTypeString() const override { return p_op_->GetTypeString(); }
292 
293  private:
 // The wrapped operation; owned by this wrapper.
294  std::unique_ptr<DeviceOp> p_op_;
295 
296 #endif // __HIPCC_RTC__
297 };
298 
299 } // namespace device
300 } // namespace tensor_operation
301 } // namespace ck
Definition: ck.hpp:268
int32_t index_t
Definition: ck.hpp:299
Definition: device_base.hpp:197
Definition: device_base.hpp:223
Definition: device_gemm_multiple_d.hpp:36
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:37
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:126
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
Definition: device_gemm_multiple_d.hpp:80
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:81
Wrapper for backward compatibility that allows instances of DeviceGemmMultipleDSplitK to be used in contexts that call MakeArgumentPointer without a KBatch argument (the wrapper passes KBatch = 1).
Definition: device_gemm_multiple_d.hpp:227
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_multiple_d.hpp:286
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_multiple_d.hpp:249
DeviceGemmMultipleDSplitKWrapper(std::unique_ptr< DeviceOp > p_op)
Definition: device_gemm_multiple_d.hpp:244
std::string GetTypeString() const override
Definition: device_gemm_multiple_d.hpp:291
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition: device_gemm_multiple_d.hpp:254
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:240
Definition: device_gemm_multiple_d.hpp:167
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideAScale, ck::index_t StrideB, ck::index_t StrideBScale, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:168