/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp Source File
device_gemm_v2.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 namespace tensor_operation {
10 namespace device {
11 
// Abstract interface for a GEMM device operation. The template parameters
// name a layout and a data type for each of A, B and C, plus one element-wise
// operation type per operand. Every member function below is pure virtual, so
// a concrete operation has to derive from this struct and override all of them.
12 template <typename ALayout,
13  typename BLayout,
14  typename CLayout,
15  typename ADataType,
16  typename BDataType,
17  typename CDataType,
18  typename AElementwiseOperation,
19  typename BElementwiseOperation,
20  typename CElementwiseOperation>
21 struct DeviceGemmV2 : public BaseOperator
22 {
    // Returns a std::unique_ptr<BaseArgument> built from the given operands;
    // ownership of the returned object passes to the caller.
    //   p_a, p_b       read-only untyped pointers (const void*)
    //   p_c            writable untyped pointer (void*)
    //   M, N, K        ck::index_t; presumably the GEMM problem sizes - TODO confirm
    //   StrideA/B/C    ck::index_t; presumably one stride per matrix - TODO confirm
    //   KSplit         ck::index_t; presumably the split-K factor - TODO confirm
    //   a/b/c_element_op  element-wise operation objects, passed by value
23  virtual std::unique_ptr<BaseArgument>
24  MakeArgumentPointer(const void* p_a,
25  const void* p_b,
26  void* p_c,
27  ck::index_t M,
28  ck::index_t N,
29  ck::index_t K,
30  ck::index_t StrideA,
31  ck::index_t StrideB,
32  ck::index_t StrideC,
33  ck::index_t KSplit,
34  AElementwiseOperation a_element_op,
35  BElementwiseOperation b_element_op,
36  CElementwiseOperation c_element_op) = 0;
37 
    // Returns a std::unique_ptr<BaseInvoker>; takes no arguments.
38  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
39 
    // Parameterless queries answered by the overriding class. Their exact
    // meaning is not visible in this header.
    // NOTE(review): GetPermuteA/GetPermuteB presumably report whether the A/B
    // operand is stored permuted, and GetKPerBlock the K extent handled per
    // block - confirm against an implementation.
40  virtual bool GetPermuteA() = 0;
41  virtual bool GetPermuteB() = 0;
42  virtual ck::index_t GetKPerBlock() = 0;
43 };
44 
// Abstract interface for a GEMM variant that, in addition to A, B and C,
// takes a pack of "Ds" operands described by DsLayout / DsDataType.
45 template <typename ALayout,
46  typename BLayout,
47  typename DsLayout,
48  typename CLayout,
49  typename ADataType,
50  typename BDataType,
51  typename DsDataType,
52  typename CDataType,
53  typename AElementwiseOperation,
54  typename BElementwiseOperation,
55  typename CElementwiseOperation>
// NOTE(review): the declaration line that names this struct (listing line 56)
// is absent from this extract; the template header runs straight into the
// opening brace. Restore it from the original header before compiling.
57 {
    // Number of D operands, taken from DsDataType::Size(). It is the extent
    // of both std::array parameters of MakeArgumentPointer below.
58  static constexpr index_t NumDTensor = DsDataType::Size();
59 
    // Returns a std::unique_ptr<BaseArgument> built from the given operands;
    // ownership of the returned object passes to the caller.
    //   p_a, p_b       read-only untyped pointers (const void*)
    //   p_ds           NumDTensor read-only untyped pointers, one per D operand
    //   p_c            writable untyped pointer (void*)
    //   M, N, K        ck::index_t; presumably the GEMM problem sizes - TODO confirm
    //   StrideA/B/C    ck::index_t; presumably one stride per matrix - TODO confirm
    //   DsStrides      NumDTensor ck::index_t values, one per D operand
    //   KSplit         ck::index_t; presumably the split-K factor - TODO confirm
    //   a/b/c_element_op  element-wise operation objects, passed by value
60  virtual std::unique_ptr<BaseArgument>
61  MakeArgumentPointer(const void* p_a,
62  const void* p_b,
63  std::array<const void*, NumDTensor> p_ds,
64  void* p_c,
65  ck::index_t M,
66  ck::index_t N,
67  ck::index_t K,
68  ck::index_t StrideA,
69  ck::index_t StrideB,
70  std::array<ck::index_t, NumDTensor> DsStrides,
71  ck::index_t StrideC,
72  ck::index_t KSplit,
73  AElementwiseOperation a_element_op,
74  BElementwiseOperation b_element_op,
75  CElementwiseOperation c_element_op) = 0;
76 
    // Returns a std::unique_ptr<BaseInvoker>; takes no arguments.
77  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
78 };
79 
// Abstract interface for a GEMM variant that carries a separate scale input
// for B: the template adds BScaleType and two compile-time index_t values
// (ScaleBlockN, ScaleBlockK), and MakeArgumentPointer adds p_b_scale and
// StrideScaleB.
80 template <typename ALayout,
81  typename BLayout,
82  typename CLayout,
83  typename ADataType,
84  typename BDataType,
85  typename BScaleType,
86  typename CDataType,
87  index_t ScaleBlockN,
88  index_t ScaleBlockK,
89  typename AElementwiseOperation,
90  typename BElementwiseOperation,
91  typename CElementwiseOperation>
// NOTE(review): the declaration line that names this struct (listing line 92)
// is absent from this extract; the template header runs straight into the
// opening brace. Restore it from the original header before compiling.
93 {
    // Returns a std::unique_ptr<BaseArgument> built from the given operands;
    // ownership of the returned object passes to the caller.
    //   p_a, p_b       read-only untyped pointers (const void*)
    //   p_c            writable untyped pointer (void*)
    //   M, N, K        ck::index_t; presumably the GEMM problem sizes - TODO confirm
    //   StrideA/B/C    ck::index_t; presumably one stride per matrix - TODO confirm
    //   StrideScaleB   ck::index_t; presumably the stride of the B-scale data - TODO confirm
    //   p_b_scale      read-only untyped pointer (const void*) to the B-scale data
    //   KSplit         ck::index_t; presumably the split-K factor - TODO confirm
    //   a/b/c_element_op  element-wise operation objects, passed by value
94  virtual std::unique_ptr<BaseArgument>
95  MakeArgumentPointer(const void* p_a,
96  const void* p_b,
97  void* p_c,
98  ck::index_t M,
99  ck::index_t N,
100  ck::index_t K,
101  ck::index_t StrideA,
102  ck::index_t StrideB,
103  ck::index_t StrideC,
104  ck::index_t StrideScaleB,
105  const void* p_b_scale,
106  ck::index_t KSplit,
107  AElementwiseOperation a_element_op,
108  BElementwiseOperation b_element_op,
109  CElementwiseOperation c_element_op) = 0;
110 
    // Returns a std::unique_ptr<BaseInvoker>; takes no arguments.
111  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
112 
    // Parameterless queries answered by the overriding class. This interface
    // declares GetPermuteB but no GetPermuteA.
    // NOTE(review): exact meaning of both values is not visible here - confirm
    // against an implementation.
113  virtual bool GetPermuteB() = 0;
114  virtual ck::index_t GetKPerBlock() = 0;
115 };
116 
// Abstract interface with the same template parameters and the same first
// five members as DeviceGemmV2, plus one extra pure-virtual query,
// GetPreShuffleParameters().
117 template <typename ALayout,
118  typename BLayout,
119  typename CLayout,
120  typename ADataType,
121  typename BDataType,
122  typename CDataType,
123  typename AElementwiseOperation,
124  typename BElementwiseOperation,
125  typename CElementwiseOperation>
// NOTE(review): the declaration line that names this struct (listing line 126)
// is absent from this extract; the template header runs straight into the
// opening brace. Restore it from the original header before compiling.
127 {
    // Returns a std::unique_ptr<BaseArgument> built from the given operands;
    // ownership of the returned object passes to the caller.
    //   p_a, p_b       read-only untyped pointers (const void*)
    //   p_c            writable untyped pointer (void*)
    //   M, N, K        ck::index_t; presumably the GEMM problem sizes - TODO confirm
    //   StrideA/B/C    ck::index_t; presumably one stride per matrix - TODO confirm
    //   KSplit         ck::index_t; presumably the split-K factor - TODO confirm
    //   a/b/c_element_op  element-wise operation objects, passed by value
128  virtual std::unique_ptr<BaseArgument>
129  MakeArgumentPointer(const void* p_a,
130  const void* p_b,
131  void* p_c,
132  ck::index_t M,
133  ck::index_t N,
134  ck::index_t K,
135  ck::index_t StrideA,
136  ck::index_t StrideB,
137  ck::index_t StrideC,
138  ck::index_t KSplit,
139  AElementwiseOperation a_element_op,
140  BElementwiseOperation b_element_op,
141  CElementwiseOperation c_element_op) = 0;
142 
    // Returns a std::unique_ptr<BaseInvoker>; takes no arguments.
143  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
144 
    // Parameterless queries answered by the overriding class.
    // NOTE(review): the meaning of each value, in particular the int returned
    // by GetPreShuffleParameters, is not visible in this header - confirm
    // against an implementation.
145  virtual bool GetPermuteA() = 0;
146  virtual bool GetPermuteB() = 0;
147  virtual ck::index_t GetKPerBlock() = 0;
148  virtual int GetPreShuffleParameters() = 0;
149 };
150 
151 } // namespace device
152 } // namespace tensor_operation
153 } // namespace ck
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: device_base.hpp:77
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, ck::index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
Definition: device_gemm_v2.hpp:93
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, ck::index_t StrideScaleB, const void *p_b_scale, ck::index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
Definition: device_gemm_v2.hpp:22
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, ck::index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
Definition: device_gemm_v2.hpp:57
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumDTensor
Definition: device_gemm_v2.hpp:58
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > DsStrides, ck::index_t StrideC, ck::index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0