include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp Source File

include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp Source File

Composable Kernel: include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp Source File
device_batched_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <vector>
8 
9 #include "device_base.hpp"
10 
11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 
15 template <typename ALayout,
16  typename BLayout,
17  typename CLayout,
18  typename ADataType,
19  typename BDataType,
20  typename CDataType,
21  typename AElementwiseOperation,
22  typename BElementwiseOperation,
23  typename CElementwiseOperation>
25 {
26  virtual std::unique_ptr<BaseArgument>
27  MakeArgumentPointer(const void* p_a,
28  const void* p_b,
29  void* p_c,
30  ck::index_t M,
31  ck::index_t N,
32  ck::index_t K,
33  ck::index_t StrideA,
34  ck::index_t StrideB,
35  ck::index_t StrideC,
36  ck::index_t BatchStrideA,
37  ck::index_t BatchStrideB,
38  ck::index_t BatchStrideC,
39  ck::index_t Batch,
40  AElementwiseOperation a_element_op,
41  BElementwiseOperation b_element_op,
42  CElementwiseOperation c_element_op) = 0;
43 
44  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
45 };
46 
47 template <typename ALayout,
48  typename BLayout,
49  typename CLayout,
50  typename ADataType,
51  typename BDataType,
52  typename BScaleType,
53  typename CDataType,
54  index_t ScaleBlockN,
55  index_t ScaleBlockK,
56  typename AElementwiseOperation,
57  typename BElementwiseOperation,
58  typename CElementwiseOperation>
60 {
61  virtual std::unique_ptr<BaseArgument>
62  MakeArgumentPointer(const void* p_a,
63  const void* p_b,
64  void* p_c,
65  ck::index_t M,
66  ck::index_t N,
67  ck::index_t K,
68  ck::index_t StrideA,
69  ck::index_t StrideB,
70  ck::index_t StrideC,
71  ck::index_t StrideScaleB,
72  ck::index_t BatchStrideA,
73  ck::index_t BatchStrideB,
74  ck::index_t BatchStrideC,
75  ck::index_t BatchStrideScaleB,
76  const void* p_b_scale,
77  ck::index_t Batch,
78  ck::index_t KBatch,
79  AElementwiseOperation a_element_op,
80  BElementwiseOperation b_element_op,
81  CElementwiseOperation c_element_op) = 0;
82 
83  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
84 
85  virtual bool GetPermuteB() = 0;
86  virtual ck::index_t GetKPerBlock() = 0;
87 };
88 
89 template <typename ALayout,
90  typename BLayout,
91  typename CLayout,
92  typename ADataType,
93  typename BDataType,
94  typename CDataType,
95  typename AElementwiseOperation,
96  typename BElementwiseOperation,
97  typename CElementwiseOperation>
98 using DeviceBatchedGemmPtr = std::unique_ptr<DeviceBatchedGemm<ALayout,
99  BLayout,
100  CLayout,
101  ADataType,
102  BDataType,
103  CDataType,
104  AElementwiseOperation,
105  BElementwiseOperation,
106  CElementwiseOperation>>;
107 
108 } // namespace device
109 } // namespace tensor_operation
110 } // namespace ck
std::unique_ptr< DeviceBatchedGemm< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation > > DeviceBatchedGemmPtr
Definition: device_batched_gemm.hpp:106
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: device_base.hpp:77
ck::tensor_operation::device::DeviceBatchedGemm
Definition: device_batched_gemm.hpp:25
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, ck::index_t BatchStrideA, ck::index_t BatchStrideB, ck::index_t BatchStrideC, ck::index_t Batch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
ck::tensor_operation::device::DeviceBatchedGemmV2BScale
Definition: device_batched_gemm.hpp:60
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, ck::index_t StrideScaleB, ck::index_t BatchStrideA, ck::index_t BatchStrideB, ck::index_t BatchStrideC, ck::index_t BatchStrideScaleB, const void *p_b_scale, ck::index_t Batch, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0