/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp Source File
device_grouped_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <array>
7 #include <iostream>
8 #include <sstream>
9 #include <stdexcept>
10 #include <vector>
11 
12 #include "device_base.hpp"
13 #include "ck/utility/ignore.hpp"
14 
15 namespace ck {
16 namespace tensor_operation {
17 namespace device {
18 
27 template <index_t NumDTensor = 0>
29 {
30  __host__ __device__ GroupedGemmKernelArgument(const void* p_a_grid_,
31  const void* p_b_grid_,
32  std::array<const void*, NumDTensor> p_ds_grid_,
33  void* p_e_grid_,
34  index_t M_,
35  index_t N_,
36  index_t K_,
37  index_t StrideA_,
38  index_t StrideB_,
39  std::array<index_t, NumDTensor> StrideDs_,
40  index_t StrideE_)
41  : p_a_grid{p_a_grid_},
42  p_b_grid{p_b_grid_},
43  p_ds_grid{p_ds_grid_},
44  p_e_grid{p_e_grid_},
45  M{M_},
46  N{N_},
47  K{K_},
48  StrideA{StrideA_},
49  StrideB{StrideB_},
50  StrideDs{StrideDs_},
51  StrideE{StrideE_}
52  {
53  }
54 
55  const void* p_a_grid;
56  const void* p_b_grid;
57  std::array<const void*, NumDTensor> p_ds_grid;
58  void* p_e_grid;
64  std::array<index_t, NumDTensor> StrideDs;
66 
67  void Print() const
68  {
69  std::stringstream str;
70  for(auto sd : StrideDs)
71  str << sd << ",";
72 
73  std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
74  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SE:" << StrideE
75  << ", " << "SDs: {" << str.str() << "}" << "}" << std::endl;
76  }
77 };
78 
79 struct GemmDesc
80 {
83 
84  std::vector<ck::index_t> stride_Ds_;
85 };
86 
87 template <typename ALayout,
88  typename BLayout,
89  typename DsLayout,
90  typename ELayout,
91  typename ADataType,
92  typename BDataType,
93  typename DsDataType,
94  typename EDataType,
95  typename AElementwiseOperation,
96  typename BElementwiseOperation,
97  typename CElementwiseOperation>
99 {
100  static constexpr index_t NumDTensor = DsDataType::Size();
101 
102  static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
103 
104  virtual std::unique_ptr<BaseArgument>
105  MakeArgumentPointer(std::vector<const void*>& p_a,
106  std::vector<const void*>& p_b,
107  std::vector<std::array<const void*, NumDTensor>>& p_ds,
108  std::vector<void*>& p_e,
109  std::vector<GemmDesc>& gemm_desc,
110  AElementwiseOperation a_element_op,
111  BElementwiseOperation b_element_op,
112  CElementwiseOperation c_element_op) = 0;
113 
114  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
115 
116  //---------------------------------------------------------------------------------------------
127  virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
128  void* p_dev_kernel_args,
129  const void* p_host_kernel_args) const
130  {
131  ignore = p_arg;
132  ignore = p_dev_kernel_args;
133  ignore = p_host_kernel_args;
134 
135  std::ostringstream err;
136  err << "This function is not implemented by the kernel: " << this->GetTypeString()
137  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
138  throw std::runtime_error(err.str());
139  }
140 
141  //----------------------------------------------------------------------------------------------
148  virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const
149  {
150  ignore = p_arg;
151  ignore = p_dev_kernel_args;
152 
153  std::ostringstream err;
154  err << "This function is not implemented by the kernel: " << this->GetTypeString()
155  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
156  throw std::runtime_error(err.str());
157  }
158 
159  //----------------------------------------------------------------------------------------------
166  virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const
167  {
168  ignore = p_arg;
169 
170  std::ostringstream err;
171  err << "This function is not implemented by the kernel: " << this->GetTypeString()
172  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
173  throw std::runtime_error(err.str());
174  }
175 };
176 
177 } // namespace device
178 } // namespace tensor_operation
179 } // namespace ck
Definition: ck.hpp:267
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
int32_t index_t
Definition: ck.hpp:298
Definition: device_base.hpp:51
Definition: device_base.hpp:77
virtual std::string GetTypeString() const
Definition: device_base.hpp:83
Definition: device_grouped_gemm.hpp:99
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm.hpp:148
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_a, std::vector< const void * > &p_b, std::vector< std::array< const void *, NumDTensor >> &p_ds, std::vector< void * > &p_e, std::vector< GemmDesc > &gemm_desc, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args, const void *p_host_kernel_args) const
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm.hpp:127
static constexpr index_t NumDTensor
Definition: device_grouped_gemm.hpp:100
virtual size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const
Gets the device kernel argument size.
Definition: device_grouped_gemm.hpp:166
Definition: device_grouped_gemm.hpp:80
ck::index_t stride_C_
Definition: device_grouped_gemm.hpp:82
std::vector< ck::index_t > stride_Ds_
Definition: device_grouped_gemm.hpp:84
ck::index_t K_
Definition: device_grouped_gemm.hpp:81
ck::index_t stride_A_
Definition: device_grouped_gemm.hpp:82
ck::index_t N_
Definition: device_grouped_gemm.hpp:81
ck::index_t stride_B_
Definition: device_grouped_gemm.hpp:82
ck::index_t M_
Definition: device_grouped_gemm.hpp:81
Structure representing single GEMM problem arguments.
Definition: device_grouped_gemm.hpp:29
void Print() const
Definition: device_grouped_gemm.hpp:67
index_t StrideB
Definition: device_grouped_gemm.hpp:63
void * p_e_grid
Definition: device_grouped_gemm.hpp:58
index_t StrideE
Definition: device_grouped_gemm.hpp:65
index_t N
Definition: device_grouped_gemm.hpp:60
const void * p_a_grid
Definition: device_grouped_gemm.hpp:55
index_t K
Definition: device_grouped_gemm.hpp:61
std::array< index_t, NumDTensor > StrideDs
Definition: device_grouped_gemm.hpp:64
index_t StrideA
Definition: device_grouped_gemm.hpp:62
index_t M
Definition: device_grouped_gemm.hpp:59
__host__ __device__ GroupedGemmKernelArgument(const void *p_a_grid_, const void *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, void *p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_)
Definition: device_grouped_gemm.hpp:30
const void * p_b_grid
Definition: device_grouped_gemm.hpp:56
std::array< const void *, NumDTensor > p_ds_grid
Definition: device_grouped_gemm.hpp:57