/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp Source File
device_gemm_xdl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
18 
19 namespace ck {
20 namespace tensor_operation {
21 namespace device {
22 
23 template <typename ADataType,
24  typename BDataType,
25  typename CDataType,
26  typename AccDataType,
27  typename ALayout,
28  typename BLayout,
29  typename CLayout,
30  typename AElementwiseOperation,
31  typename BElementwiseOperation,
32  typename CElementwiseOperation,
33  GemmSpecialization GemmSpec,
34  ck::index_t BlockSize,
35  ck::index_t MPerBlock,
36  ck::index_t NPerBlock,
37  ck::index_t K0PerBlock,
38  ck::index_t K1,
39  ck::index_t MPerXDL,
40  ck::index_t NPerXDL,
41  ck::index_t MXdlPerWave,
42  ck::index_t NXdlPerWave,
43  typename ABlockTransferThreadClusterLengths_K0_M_K1,
44  typename ABlockTransferThreadClusterArrangeOrder,
45  typename ABlockTransferSrcAccessOrder,
46  ck::index_t ABlockTransferSrcVectorDim,
47  ck::index_t ABlockTransferSrcScalarPerVector,
48  ck::index_t ABlockTransferDstScalarPerVector_K1,
49  bool ABlockLdsAddExtraM,
50  typename BBlockTransferThreadClusterLengths_K0_N_K1,
51  typename BBlockTransferThreadClusterArrangeOrder,
52  typename BBlockTransferSrcAccessOrder,
53  ck::index_t BBlockTransferSrcVectorDim,
54  ck::index_t BBlockTransferSrcScalarPerVector,
55  ck::index_t BBlockTransferDstScalarPerVector_K1,
56  bool BBlockLdsAddExtraN,
57  ck::index_t CThreadTransferSrcDstVectorDim,
58  ck::index_t CThreadTransferDstScalarPerVector,
59  ck::index_t NumPrefetch = 1,
62 struct DeviceGemmXdl : public DeviceGemm<ALayout,
63  BLayout,
64  CLayout,
65  ADataType,
66  BDataType,
67  CDataType,
68  AElementwiseOperation,
69  BElementwiseOperation,
70  CElementwiseOperation>
71 {
73  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
74  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
75 
76  static constexpr auto I0 = Number<0>{};
77  static constexpr auto I1 = Number<1>{};
78  static constexpr auto I2 = Number<2>{};
79 
80  static constexpr auto K1Number = Number<K1>{};
81 
82  // GridwiseGemm
83  template <index_t NXdlPerWave_>
85  BlockSize,
86  ADataType, // TODO: distinguish A/B datatype
87  AccDataType,
88  CDataType,
90  ALayout,
91  BLayout,
92  CLayout,
93  AElementwiseOperation,
94  BElementwiseOperation,
95  CElementwiseOperation,
96  GemmSpec,
97  MPerBlock,
98  NPerBlock,
99  K0PerBlock,
100  MPerXDL,
101  NPerXDL,
102  K1,
103  MXdlPerWave,
104  NXdlPerWave_,
105  ABlockTransferThreadClusterLengths_K0_M_K1,
106  ABlockTransferThreadClusterArrangeOrder,
107  ABlockTransferSrcAccessOrder,
108  ABlockTransferSrcVectorDim,
109  ABlockTransferSrcScalarPerVector,
110  ABlockTransferDstScalarPerVector_K1,
111  false, // AThreadTransferSrcResetCoordinateAfterRun,
112  ABlockLdsAddExtraM,
113  BBlockTransferThreadClusterLengths_K0_N_K1,
114  BBlockTransferThreadClusterArrangeOrder,
115  BBlockTransferSrcAccessOrder,
116  BBlockTransferSrcVectorDim,
117  BBlockTransferSrcScalarPerVector,
118  BBlockTransferDstScalarPerVector_K1,
119  false, // BThreadTransferSrcResetCoordinateAfterRun,
120  BBlockLdsAddExtraN,
121  Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
122  CThreadTransferSrcDstVectorDim,
123  CThreadTransferDstScalarPerVector,
124  NumPrefetch,
125  LoopSched,
126  PipelineVer>;
129 
130  using Argument = typename GridwiseGemm64::Argument;
131 
132  // Invoker
133  struct Invoker : public BaseInvoker
134  {
135  template <typename GridwiseGemm>
136  float RunImp(const typename GridwiseGemm::Argument& karg,
137  const StreamConfig& stream_config = StreamConfig{})
138  {
139  if(stream_config.log_level_ > 0)
140  {
141  karg.Print();
142  }
143 
144  if(!GridwiseGemm::CheckValidity(karg))
145  {
146  throw std::runtime_error(
147  "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting");
148  }
149 
150  const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
151 
152  float ave_time = 0;
153 
154  if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
155  {
156  const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, true>;
157 
158  ave_time = launch_and_time_kernel(
159  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
160  }
161  else
162  {
163  const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, false>;
164 
165  ave_time = launch_and_time_kernel(
166  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
167  }
168 
169  return ave_time;
170  }
171 
173 
174  // polymorphic
175  float Run(const BaseArgument* p_arg,
176  const StreamConfig& stream_config = StreamConfig{}) override
177  {
178  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
179  }
180  };
181 
// Compile-time hook reporting whether the template parameters form a valid
// configuration. Currently a stub that accepts everything.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accept_all = true;
    return accept_all;
}
187 
// Decides whether this instance accepts `karg`.
// Returns false when the device / AccDataType combination is rejected, when
// karg.K is not a multiple of K1, or when the NXdlPerWave value for the
// current warp size is not positive. For warp size 64 with NXdlPerWave64 > 0
// the result is GridwiseGemm64::CheckValidity(karg).
188 static bool IsSupportedArgument(const Argument& karg)
189 {
// On gfx908 only float and int32_t accumulators pass.
// NOTE(review): is_same_v<AccDataType, float> appears twice in this
// condition; the second operand is redundant as written - confirm whether
// a different type was intended.
190 if(ck::get_device_name() == "gfx908")
191 {
192 if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
193 is_same_v<AccDataType, int32_t>))
194 {
195 return false;
196 }
197 }
199 {
// This branch additionally accepts a double accumulator.
// NOTE(review): same duplicated float operand as in the gfx908 branch.
200 if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
201 is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
202 {
203 return false;
204 }
205 }
// Devices matching none of the branches above are rejected.
206 else
207 {
208 return false;
209 }
210 
// K must be an exact multiple of K1.
211 if(karg.K % K1 != 0)
212 {
213 return false;
214 }
// Dispatch on warp size. When the matching NXdlPerWave constant is not
// positive, control falls through to the final `return false`.
215 if(get_warp_size() == 64)
216 {
217 if constexpr(NXdlPerWave64 > 0)
218 {
219 return GridwiseGemm64::CheckValidity(karg);
220 }
221 }
222 else
223 {
// `karg` is a GridwiseGemm64::Argument (see the Argument alias) and is
// reinterpreted here as GridwiseGemm32::Argument.
// NOTE(review): presumably the two Argument types share layout - confirm.
224 if constexpr(NXdlPerWave32 > 0)
225 {
227 reinterpret_cast<const typename GridwiseGemm32::Argument&>(karg));
228 }
229 }
230 return false;
231 }
232 
233  // polymorphic
234  bool IsSupportedArgument(const BaseArgument* p_arg) override
235  {
236  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
237  }
238 
239  static auto MakeArgument(const ADataType* p_a,
240  const BDataType* p_b,
241  CDataType* p_c,
242  index_t M,
243  index_t N,
244  index_t K,
245  index_t StrideA,
246  index_t StrideB,
247  index_t StrideC,
248  AElementwiseOperation,
249  BElementwiseOperation,
250  CElementwiseOperation)
251  {
252  return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
253  }
254 
255  static auto MakeInvoker() { return Invoker{}; }
256 
257  // polymorphic
258  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
259  const void* p_b,
260  void* p_c,
261  index_t M,
262  index_t N,
263  index_t K,
264  index_t StrideA,
265  index_t StrideB,
266  index_t StrideC,
267  AElementwiseOperation,
268  BElementwiseOperation,
269  CElementwiseOperation) override
270  {
271  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
272  static_cast<const BDataType*>(p_b),
273  static_cast<CDataType*>(p_c),
274  M,
275  N,
276  K,
277  StrideA,
278  StrideB,
279  StrideC);
280  }
281 
282  // polymorphic
283  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
284  {
285  return std::make_unique<Invoker>(Invoker{});
286  }
287 
288  // polymorphic
289  std::string GetTypeString() const override
290  {
291  auto str = std::stringstream();
292 
293  std::map<LoopScheduler, std::string> LoopSchedToString{
294  {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
295 
296  std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
297  {PipelineVersion::v2, "v2"}};
298 
299  // clang-format off
300  str << "DeviceGemmXdl"
301  << "<"
302  << BlockSize << ", "
303  << MPerBlock << ", "
304  << NPerBlock << ", "
305  << K0PerBlock << ", "
306  << K1 << ", "
307  << MPerXDL << ", "
308  << NPerXDL << ", "
309  << MXdlPerWave << ", "
310  << NXdlPerWave << ", "
311  << ABlockTransferSrcScalarPerVector << ", "
312  << ABlockTransferDstScalarPerVector_K1 << ", "
313  << BBlockTransferSrcScalarPerVector << ", "
314  << BBlockTransferDstScalarPerVector_K1
315  << ">"
316  << " NumPrefetch: "
317  << NumPrefetch << ", "
318  << "LoopScheduler: "
319  << LoopSchedToString[LoopSched] << ", "
320  << "PipelineVersion: "
321  << PipelineVersionToString[PipelineVer];
322  // clang-format on
323 
324  return str.str();
325  }
326 };
327 
328 } // namespace device
329 } // namespace tensor_operation
330 } // namespace ck
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:268
bool is_lds_direct_load_supported()
Definition: device_prop.hpp:101
std::string get_device_name()
Definition: device_prop.hpp:19
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:299
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdlops_v2r3.hpp:814
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdlops_v2r3.hpp:1003
Definition: sequence.hpp:43
Definition: integral_constant.hpp:20
Definition: device_base.hpp:197
Definition: device_base.hpp:208
Definition: device_gemm.hpp:22
Definition: device_gemm_xdl.hpp:134
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl.hpp:175
float RunImp(const typename GridwiseGemm::Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl.hpp:136
Definition: device_gemm_xdl.hpp:71
static bool IsSupportedArgument(const Argument &karg)
Definition: device_gemm_xdl.hpp:188
static constexpr auto K1Number
Definition: device_gemm_xdl.hpp:80
static constexpr auto I0
Definition: device_gemm_xdl.hpp:76
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl.hpp:234
typename GridwiseGemm64::Argument Argument
Definition: device_gemm_xdl.hpp:130
static auto MakeInvoker()
Definition: device_gemm_xdl.hpp:255
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl.hpp:283
std::string GetTypeString() const override
Definition: device_gemm_xdl.hpp:289
static constexpr auto I2
Definition: device_gemm_xdl.hpp:78
static constexpr auto I1
Definition: device_gemm_xdl.hpp:77
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl.hpp:239
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl.hpp:182
static constexpr auto NXdlPerWave32
Definition: device_gemm_xdl.hpp:74
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_gemm_xdl.hpp:73
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl.hpp:258