/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp Source File
device_gemm_mx.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 namespace tensor_operation {
10 namespace device {
11 
12 template <typename ALayout,
13  typename BLayout,
14  typename CLayout,
15  typename ADataType,
16  typename AScaleDataType,
17  typename BDataType,
18  typename BScaleDataType,
19  typename CDataType,
20  index_t ScaleBlockSize,
21  typename AElementwiseOperation,
22  typename BElementwiseOperation,
23  typename CElementwiseOperation>
24 struct DeviceGemmMX : public BaseOperator
25 {
26  virtual std::unique_ptr<BaseArgument>
27  MakeArgumentPointer(const void* p_a,
28  const void* p_a_scale,
29  const void* p_b,
30  const void* p_b_scale,
31  void* p_c,
32  ck::index_t M,
33  ck::index_t N,
34  ck::index_t K,
35  ck::index_t StrideA,
36  ck::index_t StrideAScale,
37  ck::index_t StrideB,
38  ck::index_t StrideBScale,
39  ck::index_t StrideC,
40  ck::index_t KBatch,
41  AElementwiseOperation a_element_op,
42  BElementwiseOperation b_element_op,
43  CElementwiseOperation c_element_op) = 0;
44 
45  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
46 };
47 
// NOTE(review): listing line 60 is missing between "59 ..." and "61 {" below, so the
// declaration line of this struct (its name and any base class) is not shown here.
// Recover it from the real header rather than from this page.
// Visible members: the same pure virtual MakeArgumentPointer / MakeInvokerPointer
// signatures as DeviceGemmMX earlier in this file, plus an additional pure virtual
// GetPreShuffleParameters() that returns int.
48 template <typename ALayout,
49  typename BLayout,
50  typename CLayout,
51  typename ADataType,
52  typename AScaleDataType,
53  typename BDataType,
54  typename BScaleDataType,
55  typename CDataType,
56  index_t ScaleBlockSize,
57  typename AElementwiseOperation,
58  typename BElementwiseOperation,
59  typename CElementwiseOperation>
61 {
62  virtual std::unique_ptr<BaseArgument>
63  MakeArgumentPointer(const void* p_a,
64  const void* p_a_scale,
65  const void* p_b,
66  const void* p_b_scale,
67  void* p_c,
68  ck::index_t M,
69  ck::index_t N,
70  ck::index_t K,
71  ck::index_t StrideA,
72  ck::index_t StrideAScale,
73  ck::index_t StrideB,
74  ck::index_t StrideBScale,
75  ck::index_t StrideC,
76  ck::index_t KBatch,
77  AElementwiseOperation a_element_op,
78  BElementwiseOperation b_element_op,
79  CElementwiseOperation c_element_op) = 0;
80 
81  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
82 
// Pure virtual; the meaning of the returned int is not visible here. Presumably it
// encodes a pre-shuffle configuration (inferred from the name only) — TODO confirm.
83  virtual int GetPreShuffleParameters() = 0;
84 };
85 
86 } // namespace device
87 } // namespace tensor_operation
88 } // namespace ck
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: device_base.hpp:77
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideAScale, ck::index_t StrideB, ck::index_t StrideBScale, ck::index_t StrideC, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
Definition: device_gemm_mx.hpp:25
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideAScale, ck::index_t StrideB, ck::index_t StrideBScale, ck::index_t StrideC, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0