/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp Source File
device_gemm_multiple_d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 #ifndef __HIPCC_RTC__
6 #include <array>
7 #endif
8 
9 #include "ck/utility/array.hpp"
11 
12 namespace ck {
13 namespace tensor_operation {
14 namespace device {
15 
16 // GEMM:
17 // input : A[M, K], B[K, N],
18 // input : D0[M, N], D1[M, N], ...
19 // output : E[M, N]
20 // C = a_op(A) * b_op(B)
21 // E = cde_op(C, D0, D1, ...)
22 // Assume:
23 // D0, D1, ... and E have the same layout
24 template <typename ALayout,
25  typename BLayout,
26  typename DsLayout,
27  typename ELayout,
28  typename ADataType,
29  typename BDataType,
30  typename DsDataType,
31  typename EDataType,
32  typename AElementwiseOperation,
33  typename BElementwiseOperation,
34  typename CDEElementwiseOperation>
36 {
37  static constexpr index_t NumDTensor = DsDataType::Size();
38 
39 #ifndef __HIPCC_RTC__
40  virtual std::unique_ptr<BaseArgument>
41  MakeArgumentPointer(const void* p_a,
42  const void* p_b,
43  std::array<const void*, NumDTensor> p_ds,
44  void* p_e,
45  ck::index_t M,
46  ck::index_t N,
47  ck::index_t K,
48  ck::index_t StrideA,
49  ck::index_t StrideB,
50  std::array<ck::index_t, NumDTensor> StrideDs,
51  ck::index_t StrideE,
52  AElementwiseOperation a_element_op,
53  BElementwiseOperation b_element_op,
54  CDEElementwiseOperation cde_element_op) = 0;
55 
56  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
57 #endif
58 };
59 
60 // GEMM:
61 // input : A[M, K], B[K, N],
62 // input : D0[M, N], D1[M, N], ...
63 // output : E[M, N]
64 // C = a_op(A) * b_op(B)
65 // E = cde_op(C, D0, D1, ...)
66 // Assume:
67 // D0, D1, ... and E have the same layout
68 template <typename ALayout,
69  typename BLayout,
70  typename DsLayout,
71  typename ELayout,
72  typename ADataType,
73  typename BDataType,
74  typename DsDataType,
75  typename EDataType,
76  typename AElementwiseOperation,
77  typename BElementwiseOperation,
78  typename CDEElementwiseOperation>
80 {
81  static constexpr index_t NumDTensor = DsDataType::Size();
82 
83 #ifndef __HIPCC_RTC__
84  virtual std::unique_ptr<BaseArgument>
85  MakeArgumentPointer(const void* p_a,
86  const void* p_b,
87  std::array<const void*, NumDTensor> p_ds,
88  void* p_e,
89  ck::index_t M,
90  ck::index_t N,
91  ck::index_t K,
92  ck::index_t StrideA,
93  ck::index_t StrideB,
94  std::array<ck::index_t, NumDTensor> StrideDs,
95  ck::index_t StrideE,
96  ck::index_t KBatch,
97  AElementwiseOperation a_element_op,
98  BElementwiseOperation b_element_op,
99  CDEElementwiseOperation cde_element_op) = 0;
100 
101  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
102 #endif
103 };
104 
105 // GEMM:
106 // input : A[M, K], B[K, N],
107 // input : D0[M, N], D1[M, N], ...
108 // output : E[M, N]
109 // C = a_op(A) * b_op(B)
110 // E = cde_op(C, D0, D1, ...)
111 // Assume:
112 // D0, D1, ... and E have the same layout
113 template <typename ALayout,
114  typename BLayout,
115  typename DsLayout,
116  typename ELayout,
117  typename ADataType,
118  typename BDataType,
119  typename DsDataType,
120  typename EDataType,
121  typename AElementwiseOperation,
122  typename BElementwiseOperation,
123  typename CDEElementwiseOperation>
125 {
126  static constexpr index_t NumDTensor = DsDataType::Size();
127 
128 #ifndef CK_CODE_GEN_RTC
129  virtual std::unique_ptr<BaseArgument>
130  MakeArgumentPointer(const void* p_a,
131  const void* p_b,
132  std::array<const void*, NumDTensor> p_ds,
133  void* p_e,
134  ck::index_t M,
135  ck::index_t N,
136  ck::index_t K,
137  ck::index_t StrideA,
138  ck::index_t StrideB,
139  std::array<ck::index_t, NumDTensor> StrideDs,
140  ck::index_t StrideE,
141  ck::index_t KBatch,
142  AElementwiseOperation a_element_op,
143  BElementwiseOperation b_element_op,
144  CDEElementwiseOperation cde_element_op) = 0;
145 
146  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
147 
148  virtual int GetPreShuffleParameters() = 0;
149 #endif
150 };
151 
152 template <typename ALayout,
153  typename BLayout,
154  typename DsLayout,
155  typename ELayout,
156  typename ADataType,
157  typename AScaleDataType,
158  typename BDataType,
159  typename BScaleDataType,
160  typename DsDataType,
161  typename EDataType,
162  index_t ScaleBlockSize,
163  typename AElementwiseOperation,
164  typename BElementwiseOperation,
165  typename CDEElementwiseOperation>
167 {
168  static constexpr index_t NumDTensor = DsDataType::Size();
169 
170 #ifndef CK_CODE_GEN_RTC
171  virtual std::unique_ptr<BaseArgument>
172  MakeArgumentPointer(const void* p_a,
173  const void* p_a_scale,
174  const void* p_b,
175  const void* p_b_scale,
176  std::array<const void*, NumDTensor> p_ds,
177  void* p_e,
178  ck::index_t M,
179  ck::index_t N,
180  ck::index_t K,
181  ck::index_t StrideA,
182  ck::index_t StrideAScale,
183  ck::index_t StrideB,
184  ck::index_t StrideBScale,
185  std::array<ck::index_t, NumDTensor> StrideDs,
186  ck::index_t StrideE,
187  ck::index_t KBatch,
188  AElementwiseOperation a_element_op,
189  BElementwiseOperation b_element_op,
190  CDEElementwiseOperation cde_element_op) = 0;
191 
192  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
193 
194  virtual int GetPreShuffleParameters() = 0;
195 #endif
196 };
197 
198 } // namespace device
199 } // namespace tensor_operation
200 } // namespace ck
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: device_base.hpp:77
Definition: device_gemm_multiple_d.hpp:36
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:37
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:126
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
Definition: device_gemm_multiple_d.hpp:80
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:81
Definition: device_gemm_multiple_d.hpp:167
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideAScale, ck::index_t StrideB, ck::index_t StrideBScale, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:168