/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp Source File
device_gemm_xdl_cshuffle_v2.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
18 
19 namespace ck {
20 namespace tensor_operation {
21 namespace device {
22 
23 // Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
24 // non-c-shuffle version currently has compiler issues with register spilling, which in turn
25 // cause validation failures.
// DeviceGemm_Xdl_CShuffleV2: device-level GEMM operation that implements the DeviceGemm interface
// (it publicly derives from DeviceGemm<...>). It provides argument construction (MakeArgument /
// MakeArgumentPointer), argument validation (IsSupportedArgument), kernel launch (Invoker) and a
// descriptive type string (GetTypeString).
//
// NOTE(review): the embedded source-line numbers in this listing are not contiguous (9-17, 67-68,
// 81-82, 92, 105, 141-142, 144, 185 and 226 are absent). Several names used below -- LoopSched,
// PipelineVer, GridwiseGemm64, GridwiseGemm32, Argument -- are therefore declared on lines that
// are not shown here.
26 template <typename ALayout,
27  typename BLayout,
28  typename CLayout,
29  typename ADataType,
30  typename BDataType,
31  typename CDataType,
32  typename GemmAccDataType,
33  typename CShuffleDataType,
34  typename AElementwiseOperation,
35  typename BElementwiseOperation,
36  typename CElementwiseOperation,
37  GemmSpecialization GemmSpec,
38  index_t NumGemmKPrefetchStage,
39  index_t BlockSize,
40  index_t MPerBlock,
41  index_t NPerBlock,
42  index_t KPerBlock,
43  index_t AK1,
44  index_t BK1,
45  index_t MPerXDL,
46  index_t NPerXDL,
47  index_t MXdlPerWave,
48  index_t NXdlPerWave,
49  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
50  typename ABlockTransferThreadClusterArrangeOrder,
51  typename ABlockTransferSrcAccessOrder,
52  index_t ABlockTransferSrcVectorDim,
53  index_t ABlockTransferSrcScalarPerVector,
54  index_t ABlockTransferDstScalarPerVector_AK1,
55  bool ABlockLdsExtraM,
56  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
57  typename BBlockTransferThreadClusterArrangeOrder,
58  typename BBlockTransferSrcAccessOrder,
59  index_t BBlockTransferSrcVectorDim,
60  index_t BBlockTransferSrcScalarPerVector,
61  index_t BBlockTransferDstScalarPerVector_BK1,
62  bool BBlockLdsExtraN,
63  index_t CShuffleMXdlPerWavePerShuffle,
64  index_t CShuffleNXdlPerWavePerShuffle,
65  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
66  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
    // NOTE(review): source lines 67-68 are missing from this listing; presumably they declare the
    // LoopSched and PipelineVer template parameters used further down -- confirm.
69  typename ComputeTypeA = CDataType,
70  typename ComputeTypeB = ComputeTypeA>
71 struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm<ALayout,
72  BLayout,
73  CLayout,
74  ADataType,
75  BDataType,
76  CDataType,
77  AElementwiseOperation,
78  BElementwiseOperation,
79  CElementwiseOperation>
80 {
    // Results of GetNXdlPerWave<true>() and GetNXdlPerWave<false>(). In IsSupportedArgument a
    // value > 0 is what enables the warp-size-64 path and the other (non-64) path, respectively.
83  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
84  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
85 
    // Compile-time Number<N> constants.
86  static constexpr auto I0 = Number<0>{};
87  static constexpr auto I1 = Number<1>{};
88  static constexpr auto I2 = Number<2>{};
89 
90  // GridwiseGemm
    // Template parameterized on NXdlPerWave_. The line naming it and the gridwise template it
    // refers to (source line 92) is missing from this listing; presumably a `using` alias --
    // confirm. The argument list forwards this struct's own template parameters, with
    // NXdlPerWave_ substituted for NXdlPerWave.
91  template <index_t NXdlPerWave_>
93  ALayout,
94  BLayout,
95  CLayout,
96  ADataType,
97  BDataType,
98  GemmAccDataType,
99  CShuffleDataType,
100  CDataType,
101  AElementwiseOperation,
102  BElementwiseOperation,
103  CElementwiseOperation,
104  GemmSpec,
106  NumGemmKPrefetchStage,
107  BlockSize,
108  MPerBlock,
109  NPerBlock,
110  KPerBlock,
111  AK1,
112  BK1,
113  MPerXDL,
114  NPerXDL,
115  MXdlPerWave,
116  NXdlPerWave_,
117  ABlockTransferThreadClusterLengths_AK0_M_AK1,
118  ABlockTransferThreadClusterArrangeOrder,
119  ABlockTransferSrcAccessOrder,
120  ABlockTransferSrcVectorDim,
121  ABlockTransferSrcScalarPerVector,
122  ABlockTransferDstScalarPerVector_AK1,
    // NOTE(review): hard-coded `false`; the parameter it binds to is not visible here.
123  false,
124  ABlockLdsExtraM,
125  BBlockTransferThreadClusterLengths_BK0_N_BK1,
126  BBlockTransferThreadClusterArrangeOrder,
127  BBlockTransferSrcAccessOrder,
128  BBlockTransferSrcVectorDim,
129  BBlockTransferSrcScalarPerVector,
130  BBlockTransferDstScalarPerVector_BK1,
    // NOTE(review): hard-coded `false`, mirroring the A-side argument above; meaning not visible.
131  false,
132  BBlockLdsExtraN,
133  CShuffleMXdlPerWavePerShuffle,
134  CShuffleNXdlPerWavePerShuffle,
135  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
136  CShuffleBlockTransferScalarPerVector_NPerBlock,
137  LoopSched,
138  PipelineVer,
139  ComputeTypeA,
140  ComputeTypeB>;
143 
145 
146  // Invoker
    // Launches the GEMM kernel for a given argument. Derives from BaseInvoker so it can be
    // returned through MakeInvokerPointer().
147  struct Invoker : public BaseInvoker
148  {
    // Validates `arg`, computes the launch grid and launches the kernel for the given
    // GridwiseGemm instantiation.
    //   arg           - problem description; its M, N and K members are read here.
    //   stream_config - forwarded to launch_and_time_kernel; log_level_ > 0 also prints `arg`.
    // Returns the value produced by launch_and_time_kernel (presumably the averaged kernel
    // time -- confirm against kernel_launch.hpp).
    // Throws std::runtime_error when GridwiseGemm::CheckValidity(arg) is false.
149  template <typename GridwiseGemm>
150  float RunImp(const typename GridwiseGemm::Argument& arg,
151  const StreamConfig& stream_config = StreamConfig{})
152  {
153  if(stream_config.log_level_ > 0)
154  {
155  arg.Print();
156  }
157 
158  if(!GridwiseGemm::CheckValidity(arg))
159  {
160  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
161  }
162 
    // Grid dimensions are derived from M and N only.
163  index_t gdx, gdy, gdz;
164  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
165 
166  float ave_time = 0;
    // K recomputed as CalculateAK0(arg.K) * AK1; this value (not arg.K directly) selects the
    // kernel variant below.
167  const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
168 
    // Two kernel instantiations, chosen by whether CalculateKBlockLoopTailNum(K) equals 3.
    // Both are launched with dim3(BlockSize) as the block dimension and 0 passed as
    // launch_and_time_kernel's lds_byte argument.
169  if(GridwiseGemm::CalculateKBlockLoopTailNum(K) == 3)
170  {
171  const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true>;
172  ave_time = launch_and_time_kernel(
173  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
174  }
175  else
176  {
    // NOTE(review): the meaning of the extra template argument `2` is defined by
    // kernel_gemm_xdl_cshuffle_v2, which is not visible here.
177  const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true, 2>;
178  ave_time = launch_and_time_kernel(
179  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
180  }
181 
182  return ave_time;
183  }
184 
    // NOTE(review): source line 185 is missing from this listing; presumably it holds the
    // INVOKER_RUN3_IMPL macro that supplies the Run(const Argument&, ...) overload called
    // below -- confirm against device_base.hpp.
186 
187  // polymorphic
    // Type-erased entry point: casts the BaseArgument pointer to Argument and forwards.
    // NOTE(review): the dynamic_cast result is dereferenced without a null check, so passing
    // an argument object of a different dynamic type is undefined behavior.
188  float Run(const BaseArgument* p_arg,
189  const StreamConfig& stream_config = StreamConfig{}) override
190  {
191  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
192  }
193  };
194 
    // Compile-time parameter check; currently a placeholder that always reports success.
195  static constexpr bool IsValidCompilationParameter()
196  {
197  // TODO: properly implement this check
198  return true;
199  }
200 
    // Returns whether `arg` can be run by this instance. Checks, in order:
    //   1. is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>() must be true;
    //   2. arg.K must be divisible by both AK1 and BK1, unless GemmSpec is one of the
    //      K-padding specializations (MKPadding, NKPadding, MNKPadding, KPadding);
    //   3. the CheckValidity() of the gridwise GEMM matching the current warp size.
    // Returns false when the matching NXdlPerWave64 / NXdlPerWave32 value is not > 0.
201  static bool IsSupportedArgument(const Argument& arg)
202  {
203  if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
204  {
205  return false;
206  }
207  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
208  GemmSpec == GemmSpecialization::NKPadding ||
209  GemmSpec == GemmSpecialization::MNKPadding ||
210  GemmSpec == GemmSpecialization::KPadding))
211  {
212  return false;
213  }
214 
215  if(get_warp_size() == 64)
216  {
217  if constexpr(NXdlPerWave64 > 0)
218  {
219  return GridwiseGemm64::CheckValidity(arg);
220  }
221  }
222  else
223  {
224  if constexpr(NXdlPerWave32 > 0)
225  {
    // NOTE(review): source line 226 is missing from this listing; presumably it opens a
    // `return GridwiseGemm32::CheckValidity(` call that the next line completes -- confirm.
    // `arg` (declared as Argument) is reinterpreted as GridwiseGemm32::Argument here.
227  reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
228  }
229  }
    // Reached when the NXdlPerWave value for the current warp size is not > 0.
230  return false;
231  }
232 
233  // polymorphic
    // Type-erased overload: casts to Argument and forwards to the static overload above.
    // NOTE(review): the dynamic_cast result is dereferenced without a null check.
234  bool IsSupportedArgument(const BaseArgument* p_arg) override
235  {
236  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
237  }
238 
    // Builds an Argument by value from typed pointers, problem sizes and strides.
    // The three element-wise operation parameters are unnamed and not used.
239  static auto MakeArgument(const ADataType* p_a,
240  const BDataType* p_b,
241  CDataType* p_c,
242  index_t M,
243  index_t N,
244  index_t K,
245  index_t StrideA,
246  index_t StrideB,
247  index_t StrideC,
248  AElementwiseOperation,
249  BElementwiseOperation,
250  CElementwiseOperation)
251  {
252  return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
253  }
254 
    // Returns a default-constructed Invoker by value.
255  static auto MakeInvoker() { return Invoker{}; }
256 
257  // polymorphic
    // Type-erased counterpart of MakeArgument: static_casts the void pointers to
    // ADataType / BDataType / CDataType pointers and returns a heap-allocated Argument
    // owned through std::unique_ptr<BaseArgument>. The element-wise operation parameters
    // are unnamed and not used.
258  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
259  const void* p_b,
260  void* p_c,
261  index_t M,
262  index_t N,
263  index_t K,
264  index_t StrideA,
265  index_t StrideB,
266  index_t StrideC,
267  AElementwiseOperation,
268  BElementwiseOperation,
269  CElementwiseOperation) override
270  {
271  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
272  static_cast<const BDataType*>(p_b),
273  static_cast<CDataType*>(p_c),
274  M,
275  N,
276  K,
277  StrideA,
278  StrideB,
279  StrideC);
280  }
281 
282  // polymorphic
    // Returns a heap-allocated Invoker owned through std::unique_ptr<BaseInvoker>.
    // The new object is constructed from a temporary Invoker{}.
283  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
284  {
285  return std::make_unique<Invoker>(Invoker{});
286  }
287 
288  // polymorphic
    // Builds a human-readable description of this instance: the struct name, a subset of
    // the template parameters in angle brackets, then the loop scheduler and pipeline
    // version names.
289  std::string GetTypeString() const override
290  {
291  auto str = std::stringstream();
292 
    // Enum-to-name tables. They are looked up with operator[] below, so an enum value
    // that is not listed here is default-inserted and printed as an empty string.
293  std::map<LoopScheduler, std::string> LoopSchedToString{
294  {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
295 
296  std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
297  {PipelineVersion::v2, "v2"}};
298 
299  // clang-format off
300  str << "DeviceGemm_Xdl_CShuffleV2"
301  << "<"
302  << getGemmSpecializationString(GemmSpec) << ", "
303  << BlockSize << ", "
304  << MPerBlock << ", "
305  << NPerBlock << ", "
306  << KPerBlock << ", "
307  << AK1 << ", "
308  << BK1 << ", "
309  << MPerXDL << ", "
310  << NPerXDL << ", "
311  << MXdlPerWave << ", "
312  << NXdlPerWave << ", "
313  << ABlockTransferSrcScalarPerVector << ", "
314  << BBlockTransferSrcScalarPerVector << ", "
315  << CShuffleMXdlPerWavePerShuffle << ", "
316  << CShuffleNXdlPerWavePerShuffle
317  << ">"
318  << " LoopScheduler: "
319  << LoopSchedToString[LoopSched] << ", "
320  << "PipelineVersion: "
321  << PipelineVersionToString[PipelineVer];
322  // clang-format on
323 
324  return str.str();
325  }
326 };
327 
328 } // namespace device
329 } // namespace tensor_operation
330 } // namespace ck
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:268
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:299
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdl_cshuffle_v2.hpp:508
Definition: gridwise_gemm_xdl_cshuffle_v2.hpp:126
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdl_cshuffle_v2.hpp:609
Definition: integral_constant.hpp:20
Definition: device_base.hpp:197
Definition: device_base.hpp:208
Definition: device_gemm_xdl_cshuffle_v2.hpp:148
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_cshuffle_v2.hpp:188
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_cshuffle_v2.hpp:150
Definition: device_gemm_xdl_cshuffle_v2.hpp:80
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl_cshuffle_v2.hpp:239
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl_cshuffle_v2.hpp:258
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_gemm_xdl_cshuffle_v2.hpp:83
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_cshuffle_v2.hpp:234
typename GridwiseGemm64::Argument Argument
Definition: device_gemm_xdl_cshuffle_v2.hpp:144
static constexpr auto I0
Definition: device_gemm_xdl_cshuffle_v2.hpp:86
std::string GetTypeString() const override
Definition: device_gemm_xdl_cshuffle_v2.hpp:289
static constexpr auto I1
Definition: device_gemm_xdl_cshuffle_v2.hpp:87
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl_cshuffle_v2.hpp:195
static auto MakeInvoker()
Definition: device_gemm_xdl_cshuffle_v2.hpp:255
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_cshuffle_v2.hpp:201
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_cshuffle_v2.hpp:283
static constexpr auto I2
Definition: device_gemm_xdl_cshuffle_v2.hpp:88
static constexpr auto NXdlPerWave32
Definition: device_gemm_xdl_cshuffle_v2.hpp:84
Definition: device_gemm.hpp:22