/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp Source File
device_gemm_xdl_cshuffle_v2.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
18 
19 namespace ck {
20 namespace tensor_operation {
21 namespace device {
22 
23 // Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
24 // non-c-shuffle version currently has compiler issues with register spilling, which in turn cause
25 // validation failures.
// Device-level GEMM operator: a DeviceGemm<...> implementation that forwards its layouts, data
// types, element-wise operations and tuning parameters to the GridwiseGemm alias defined below
// (declared in gridwise_gemm_xdl_cshuffle_v2.hpp per the member list of this page).
// NOTE(review): this rendered listing omits every source line that carried a hyperlink, so the
// declaration below is incomplete; consult the header itself for the exact text.
26 template <typename ALayout,
27  typename BLayout,
28  typename CLayout,
29  typename ADataType,
30  typename BDataType,
31  typename CDataType,
32  typename GemmAccDataType,
33  typename CShuffleDataType,
34  typename AElementwiseOperation,
35  typename BElementwiseOperation,
36  typename CElementwiseOperation,
37  GemmSpecialization GemmSpec,
38  index_t NumGemmKPrefetchStage,
39  index_t BlockSize,
40  index_t MPerBlock,
41  index_t NPerBlock,
42  index_t KPerBlock,
43  index_t AK1,
44  index_t BK1,
45  index_t MPerXDL,
46  index_t NPerXDL,
47  index_t MXdlPerWave,
48  index_t NXdlPerWave,
49  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
50  typename ABlockTransferThreadClusterArrangeOrder,
51  typename ABlockTransferSrcAccessOrder,
52  index_t ABlockTransferSrcVectorDim,
53  index_t ABlockTransferSrcScalarPerVector,
54  index_t ABlockTransferDstScalarPerVector_AK1,
55  bool ABlockLdsExtraM,
56  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
57  typename BBlockTransferThreadClusterArrangeOrder,
58  typename BBlockTransferSrcAccessOrder,
59  index_t BBlockTransferSrcVectorDim,
60  index_t BBlockTransferSrcScalarPerVector,
61  index_t BBlockTransferDstScalarPerVector_BK1,
62  bool BBlockLdsExtraN,
63  index_t CShuffleMXdlPerWavePerShuffle,
64  index_t CShuffleNXdlPerWavePerShuffle,
65  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
66  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
// NOTE(review): original lines 67-68 are missing from this listing; they presumably declare the
// LoopSched (LoopScheduler) and PipelineVer (PipelineVersion) template parameters that are passed
// to GridwiseGemm and printed by GetTypeString — confirm against the header.
69  typename ComputeTypeA = CDataType,
70  typename ComputeTypeB = ComputeTypeA>
71 struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm<ALayout,
72  BLayout,
73  CLayout,
74  ADataType,
75  BDataType,
76  CDataType,
77  AElementwiseOperation,
78  BElementwiseOperation,
79  CElementwiseOperation>
80 {
82 
// Compile-time index constants (Number<0>, Number<1>, Number<2>).
83  static constexpr auto I0 = Number<0>{};
84  static constexpr auto I1 = Number<1>{};
85  static constexpr auto I2 = Number<2>{};
86 
87  // GridwiseGemm
// NOTE(review): the line that opens this alias (original line 88) and original line 101 are
// missing from this listing. The literal `false` arguments below are positional GridwiseGemm
// template parameters whose names are not visible here — check GridwiseGemm's parameter list.
89  ALayout,
90  BLayout,
91  CLayout,
92  ADataType,
93  BDataType,
94  GemmAccDataType,
95  CShuffleDataType,
96  CDataType,
97  AElementwiseOperation,
98  BElementwiseOperation,
99  CElementwiseOperation,
100  GemmSpec,
102  NumGemmKPrefetchStage,
103  BlockSize,
104  MPerBlock,
105  NPerBlock,
106  KPerBlock,
107  AK1,
108  BK1,
109  MPerXDL,
110  NPerXDL,
111  MXdlPerWave,
112  NXdlPerWave,
113  ABlockTransferThreadClusterLengths_AK0_M_AK1,
114  ABlockTransferThreadClusterArrangeOrder,
115  ABlockTransferSrcAccessOrder,
116  ABlockTransferSrcVectorDim,
117  ABlockTransferSrcScalarPerVector,
118  ABlockTransferDstScalarPerVector_AK1,
119  false,
120  ABlockLdsExtraM,
121  BBlockTransferThreadClusterLengths_BK0_N_BK1,
122  BBlockTransferThreadClusterArrangeOrder,
123  BBlockTransferSrcAccessOrder,
124  BBlockTransferSrcVectorDim,
125  BBlockTransferSrcScalarPerVector,
126  BBlockTransferDstScalarPerVector_BK1,
127  false,
128  BBlockLdsExtraN,
129  CShuffleMXdlPerWavePerShuffle,
130  CShuffleNXdlPerWavePerShuffle,
131  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
132  CShuffleBlockTransferScalarPerVector_NPerBlock,
133  LoopSched,
134  PipelineVer,
135  ComputeTypeA,
136  ComputeTypeB>;
137 
// NOTE(review): original line 138, missing here, defines `Argument` as
// `typename GridwiseGemm::Argument` (per this page's member list); all factories and Run overloads
// below operate on that type.
139 
140  // Invoker
// Launches the GridwiseGemm kernel for a concrete Argument.
141  struct Invoker : public BaseInvoker
142  {
// Runs the GEMM described by `arg`:
// - prints `arg` when stream_config.log_level_ > 0;
// - rejects an invalid setting with std::runtime_error;
// - sizes the grid from (arg.M, arg.N);
// - launches one of two kernel_gemm_xdl_cshuffle_v2 instantiations.
// Returns the float produced by launch_and_time_kernel (presumably the average kernel time when
// timing is enabled — TODO confirm units in kernel_launch.hpp).
143  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
144  {
145  if(stream_config.log_level_ > 0)
146  {
147  arg.Print();
148  }
149 
// NOTE(review): the `if` condition (original line 150) is missing from this listing; it
// presumably tests `!GridwiseGemm::CheckValidity(arg)` — confirm against the header.
151  {
152  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
153  }
154 
155  index_t gdx, gdy, gdz;
156  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
157 
158  float ave_time = 0;
// K expressed as (number of AK0 tiles) * AK1 — presumably arg.K rounded up to a multiple of AK1.
// It is not used by any visible line, so it is presumably consumed by the missing branch condition.
159  const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
160 
// NOTE(review): the branch condition (original line 161) is missing from this listing. This page's
// member list references GridwiseGemm::CalculateKBlockLoopTailNum, which presumably selects
// between the two instantiations below; the second passes an extra template argument `2`.
162  {
163  const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true>;
// Launch config: grid (gdx, gdy, gdz), BlockSize threads per block, 0 bytes of dynamic LDS.
// `arg` is forwarded as the kernel argument.
164  ave_time = launch_and_time_kernel(
165  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
166  }
167  else
168  {
169  const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true, 2>;
170  ave_time = launch_and_time_kernel(
171  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
172  }
173 
174  return ave_time;
175  }
176 
177  // polymorphic
178  float Run(const BaseArgument* p_arg,
179  const StreamConfig& stream_config = StreamConfig{}) override
180  {
181  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
182  }
183  };
184 
// Compile-time validity check for this instantiation's template parameters.
// Placeholder: no validation is performed yet, so every instantiation reports true.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool always_valid = true;
    return always_valid;
}
190 
191  static bool IsSupportedArgument(const Argument& arg)
192  {
193  if(!ck::is_xdl_supported())
194  {
195  return false;
196  }
197 
198  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
199  GemmSpec == GemmSpecialization::NKPadding ||
200  GemmSpec == GemmSpecialization::MNKPadding ||
201  GemmSpec == GemmSpecialization::KPadding))
202  {
203  return false;
204  }
205 
206  return GridwiseGemm::CheckValidity(arg);
207  }
208 
209  // polymorphic
210  bool IsSupportedArgument(const BaseArgument* p_arg) override
211  {
212  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
213  }
214 
215  static auto MakeArgument(const ADataType* p_a,
216  const BDataType* p_b,
217  CDataType* p_c,
218  index_t M,
219  index_t N,
220  index_t K,
221  index_t StrideA,
222  index_t StrideB,
223  index_t StrideC,
224  AElementwiseOperation,
225  BElementwiseOperation,
226  CElementwiseOperation)
227  {
228  return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
229  }
230 
231  static auto MakeInvoker() { return Invoker{}; }
232 
233  // polymorphic
234  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
235  const void* p_b,
236  void* p_c,
237  index_t M,
238  index_t N,
239  index_t K,
240  index_t StrideA,
241  index_t StrideB,
242  index_t StrideC,
243  AElementwiseOperation,
244  BElementwiseOperation,
245  CElementwiseOperation) override
246  {
247  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
248  static_cast<const BDataType*>(p_b),
249  static_cast<CDataType*>(p_c),
250  M,
251  N,
252  K,
253  StrideA,
254  StrideB,
255  StrideC);
256  }
257 
258  // polymorphic
259  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
260  {
261  return std::make_unique<Invoker>(Invoker{});
262  }
263 
264  // polymorphic
265  std::string GetTypeString() const override
266  {
267  auto str = std::stringstream();
268 
269  std::map<LoopScheduler, std::string> LoopSchedToString{
270  {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
271 
272  std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
273  {PipelineVersion::v2, "v2"}};
274 
275  // clang-format off
276  str << "DeviceGemm_Xdl_CShuffleV2"
277  << "<"
278  << getGemmSpecializationString(GemmSpec) << ", "
279  << BlockSize << ", "
280  << MPerBlock << ", "
281  << NPerBlock << ", "
282  << KPerBlock << ", "
283  << AK1 << ", "
284  << BK1 << ", "
285  << MPerXDL << ", "
286  << NPerXDL << ", "
287  << MXdlPerWave << ", "
288  << NXdlPerWave << ", "
289  << ABlockTransferSrcScalarPerVector << ", "
290  << BBlockTransferSrcScalarPerVector << ", "
291  << CShuffleMXdlPerWavePerShuffle << ", "
292  << CShuffleNXdlPerWavePerShuffle
293  << ">"
294  << " LoopScheduler: "
295  << LoopSchedToString[LoopSched] << ", "
296  << "PipelineVersion: "
297  << PipelineVersionToString[PipelineVer];
298  // clang-format on
299 
300  return str.str();
301  }
302 };
303 
304 } // namespace device
305 } // namespace tensor_operation
306 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:267
bool is_xdl_supported()
Definition: device_prop.hpp:68
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:298
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdl_cshuffle_v2.hpp:500
Definition: gridwise_gemm_xdl_cshuffle_v2.hpp:118
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_xdl_cshuffle_v2.hpp:136
static constexpr __host__ index_t CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v2.hpp:696
static __host__ auto CalculateAK0(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v2.hpp:156
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdl_cshuffle_v2.hpp:585
Definition: integral_constant.hpp:20
Definition: device_base.hpp:51
Definition: device_base.hpp:62
Definition: device_gemm_xdl_cshuffle_v2.hpp:142
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_cshuffle_v2.hpp:143
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_cshuffle_v2.hpp:178
Definition: device_gemm_xdl_cshuffle_v2.hpp:80
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl_cshuffle_v2.hpp:215
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl_cshuffle_v2.hpp:234
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_cshuffle_v2.hpp:210
static constexpr auto I0
Definition: device_gemm_xdl_cshuffle_v2.hpp:83
std::string GetTypeString() const override
Definition: device_gemm_xdl_cshuffle_v2.hpp:265
typename GridwiseGemm::Argument Argument
Definition: device_gemm_xdl_cshuffle_v2.hpp:138
static constexpr auto I1
Definition: device_gemm_xdl_cshuffle_v2.hpp:84
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl_cshuffle_v2.hpp:185
static auto MakeInvoker()
Definition: device_gemm_xdl_cshuffle_v2.hpp:231
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_cshuffle_v2.hpp:191
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_cshuffle_v2.hpp:259
static constexpr auto I2
Definition: device_gemm_xdl_cshuffle_v2.hpp:85
Definition: device_gemm.hpp:22