/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp Source File
device_gemm_xdl_streamk.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
24 template <typename ADataType,
25  typename BDataType,
26  typename CDataType,
27  typename AccDataType,
28  typename ALayout,
29  typename BLayout,
30  typename CLayout,
31  typename AElementwiseOperation,
32  typename BElementwiseOperation,
33  typename CElementwiseOperation,
34  ck::index_t BlockSize,
35  ck::index_t MPerBlock,
36  ck::index_t NPerBlock,
37  ck::index_t K0PerBlock,
38  ck::index_t K1,
39  ck::index_t MPerXDL,
40  ck::index_t NPerXDL,
41  ck::index_t MXdlPerWave,
42  ck::index_t NXdlPerWave,
43  typename ABlockTransferThreadClusterLengths_K0_M_K1,
44  typename ABlockTransferThreadClusterArrangeOrder,
45  typename ABlockTransferSrcAccessOrder,
46  ck::index_t ABlockTransferSrcVectorDim,
47  ck::index_t ABlockTransferSrcScalarPerVector,
48  ck::index_t ABlockTransferDstScalarPerVector_K1,
49  ck::index_t ABlockLdsAddExtraM,
50  typename BBlockTransferThreadClusterLengths_K0_N_K1,
51  typename BBlockTransferThreadClusterArrangeOrder,
52  typename BBlockTransferSrcAccessOrder,
53  ck::index_t BBlockTransferSrcVectorDim,
54  ck::index_t BBlockTransferSrcScalarPerVector,
55  ck::index_t BBlockTransferDstScalarPerVector_K1,
56  ck::index_t BBlockLdsAddExtraN,
57  index_t CShuffleMRepeatPerShuffle,
58  index_t CShuffleNRepeatPerShuffle,
59  typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60  index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
61 struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
62  BLayout,
63  CLayout,
64  ADataType,
65  BDataType,
66  CDataType,
67  AElementwiseOperation,
68  BElementwiseOperation,
69  CElementwiseOperation>
70 {
71  static constexpr auto I0 = Number<0>{};
72  static constexpr auto I1 = Number<1>{};
73  static constexpr auto I2 = Number<2>{};
74  static constexpr auto I3 = Number<3>{};
75 
77  BlockSize,
79  NPerBlock,
80  K0PerBlock * K1,
82  ADataType, // TODO: distinguish A/B datatype
83  AccDataType,
84  CDataType,
85  ALayout,
86  BLayout,
87  CLayout,
88  AElementwiseOperation,
89  BElementwiseOperation,
90  CElementwiseOperation,
91  MPerBlock,
92  NPerBlock,
93  K0PerBlock,
94  MPerXDL,
95  NPerXDL,
96  K1,
97  MXdlPerWave,
98  NXdlPerWave,
99  ABlockTransferThreadClusterLengths_K0_M_K1,
100  ABlockTransferThreadClusterArrangeOrder,
101  ABlockTransferSrcAccessOrder,
102  ABlockTransferSrcVectorDim,
103  ABlockTransferSrcScalarPerVector,
104  ABlockTransferDstScalarPerVector_K1,
105  false, // AThreadTransferSrcResetCoordinateAfterRun,
106  ABlockLdsAddExtraM,
107  BBlockTransferThreadClusterLengths_K0_N_K1,
108  BBlockTransferThreadClusterArrangeOrder,
109  BBlockTransferSrcAccessOrder,
110  BBlockTransferSrcVectorDim,
111  BBlockTransferSrcScalarPerVector,
112  BBlockTransferDstScalarPerVector_K1,
113  false, // BThreadTransferSrcResetCoordinateAfterRun,
114  BBlockLdsAddExtraN,
115  CShuffleMRepeatPerShuffle,
116  CShuffleNRepeatPerShuffle,
117  CBlockTransferScalarPerVector_NWaveNPerXDL,
118  CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
119 
121 
122  // Invoker
123  struct Invoker : public BaseInvoker
124  {
125  void Print(const Argument& karg) { karg.Print(); }
126 
127  float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
128  {
129  if(stream_config.log_level_ > 0)
130  {
131  Print(karg);
132  }
133  if(!GridwiseGemm::CheckValidity(karg))
134  {
135  throw std::runtime_error(
136  "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
137  "setting");
138  }
139 
140  dim3 grid_dims = karg.block_mapping.get_grid_dims();
141 
142  float ave_time = 0;
143 
144  const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
145 
146  // TODO: remove clear buffer for streamk kernels
147  if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
149  {
150  hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
151  0,
152  karg.M * karg.N * sizeof(CDataType),
153  stream_config.stream_id_));
154  ave_time = launch_and_time_kernel(stream_config,
155  kernel,
156  grid_dims,
157  dim3(BlockSize),
158  0,
159  karg.p_a_grid,
160  karg.p_b_grid,
161  karg.p_c_grid,
162  karg.p_workspace_,
163  karg.M,
164  karg.N,
165  karg.K,
166  karg.StrideA,
167  karg.StrideB,
168  karg.StrideC,
169  karg.block_mapping);
170  }
171  else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
173  {
174  char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_) +
175  karg.block_mapping.get_workspace_size_for_acc(
176  sizeof(typename GridwiseGemm::FloatAcc));
177  auto preprocess = [&]() {
178  hipGetErrorString(
179  hipMemsetAsync(workspace_semaphore,
180  0,
181  karg.block_mapping.get_workspace_size_for_semaphore(),
182  stream_config.stream_id_));
183  };
184 
185  ave_time = launch_and_time_kernel_with_preprocess(stream_config,
186  preprocess,
187  kernel,
188  grid_dims,
189  dim3(BlockSize),
190  0,
191  karg.p_a_grid,
192  karg.p_b_grid,
193  karg.p_c_grid,
194  karg.p_workspace_,
195  karg.M,
196  karg.N,
197  karg.K,
198  karg.StrideA,
199  karg.StrideB,
200  karg.StrideC,
201  karg.block_mapping);
202  }
203 
204  return ave_time;
205  }
206 
207  // polymorphic
208  float Run(const BaseArgument* p_arg,
209  const StreamConfig& stream_config = StreamConfig{}) override
210  {
211  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
212  }
213  };
214 
215  size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
216  {
217  const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
218  if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
220  {
221  return p_arg->block_mapping.get_workspace_size(sizeof(typename GridwiseGemm::FloatAcc));
222  }
223  else
224  {
225  return 0;
226  }
227  }
228 
230  void* p_workspace,
231  const StreamConfig& = StreamConfig{}) const override
232  {
233  Argument* pArg_ = dynamic_cast<Argument*>(pArg);
234 
235  pArg_->p_workspace_ = p_workspace;
236  }
237 
238  static constexpr bool IsValidCompilationParameter()
239  {
240  // TODO: properly implement this check
241  return true;
242  }
243 
244  static bool IsSupportedArgument(const Argument& karg)
245  {
246  if(!(ck::is_xdl_supported()))
247  {
248  return false;
249  }
250  return GridwiseGemm::CheckValidity(karg);
251  }
252 
253  // polymorphic
254  bool IsSupportedArgument(const BaseArgument* p_arg) override
255  {
256  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
257  }
258 
259  static auto MakeArgument(const ADataType* p_a,
260  const BDataType* p_b,
261  CDataType* p_c,
262  index_t M,
263  index_t N,
264  index_t K,
265  index_t StrideA,
266  index_t StrideB,
267  index_t StrideC,
268  AElementwiseOperation,
269  BElementwiseOperation,
270  CElementwiseOperation,
271  uint32_t NumSKBlocks = 0xffffffff)
272  {
273  const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
274  int occupancy, num_cu;
275  hipError_t rtn;
276  rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
277  &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
278  hip_check_error(rtn);
279 
280  hipDeviceProp_t dev_prop;
281  hipDevice_t dev;
282  rtn = hipGetDevice(&dev);
283  hip_check_error(rtn);
284  rtn = hipGetDeviceProperties(&dev_prop, dev);
285  hip_check_error(rtn);
286  num_cu = dev_prop.multiProcessorCount;
287 
288  return Argument{p_a,
289  p_b,
290  p_c,
291  M,
292  N,
293  K,
294  StrideA,
295  StrideB,
296  StrideC,
297  static_cast<uint32_t>(num_cu),
298  static_cast<uint32_t>(occupancy),
299  NumSKBlocks};
300  }
301 
302  static auto MakeInvoker() { return Invoker{}; }
303 
304  // polymorphic
305  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
306  const void* p_b,
307  void* p_c,
308  index_t M,
309  index_t N,
310  index_t K,
311  index_t StrideA,
312  index_t StrideB,
313  index_t StrideC,
314  AElementwiseOperation,
315  BElementwiseOperation,
316  CElementwiseOperation,
317  index_t NumSKBlocks = 0) override
318  {
319  const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
320  int occupancy, num_cu;
321  hipError_t rtn;
322  rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
323  &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
324  hip_check_error(rtn);
325 
326  hipDeviceProp_t dev_prop;
327  hipDevice_t dev;
328  rtn = hipGetDevice(&dev);
329  hip_check_error(rtn);
330  rtn = hipGetDeviceProperties(&dev_prop, dev);
331  hip_check_error(rtn);
332  num_cu = dev_prop.multiProcessorCount;
333 
334  return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
335  reinterpret_cast<const BDataType*>(p_b),
336  reinterpret_cast<CDataType*>(p_c),
337  M,
338  N,
339  K,
340  StrideA,
341  StrideB,
342  StrideC,
343  static_cast<uint32_t>(num_cu),
344  static_cast<uint32_t>(occupancy),
345  static_cast<uint32_t>(NumSKBlocks));
346  }
347 
348  // polymorphic
349  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
350  {
351  return std::make_unique<Invoker>(Invoker{});
352  }
353 
354  // polymorphic
355  std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
356 };
357 
358 } // namespace device
359 } // namespace tensor_operation
360 } // namespace ck
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:91
Definition: ck.hpp:267
@ Atomic
Definition: block_to_ctile_map.hpp:1011
@ Reduction
Definition: block_to_ctile_map.hpp:1012
bool is_xdl_supported()
Definition: device_prop.hpp:68
int32_t index_t
Definition: ck.hpp:298
unsigned int uint32_t
Definition: stdint.h:126
Definition: stream_config.hpp:10
Definition: block_to_ctile_map.hpp:1021
Definition: gridwise_gemm_xdlops_streamk.hpp:137
Definition: gridwise_gemm_xdlops_streamk.hpp:112
__host__ static constexpr __device__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdlops_streamk.hpp:308
static std::string GetTypeString()
Definition: gridwise_gemm_xdlops_streamk.hpp:1156
FloatAcc_ FloatAcc
Definition: gridwise_gemm_xdlops_streamk.hpp:129
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_streamk.hpp:286
Definition: integral_constant.hpp:20
Definition: device_base.hpp:51
void * p_workspace_
Definition: device_base.hpp:58
Definition: device_base.hpp:62
Definition: device_gemm_streamk.hpp:25
Definition: device_gemm_xdl_streamk.hpp:124
void Print(const Argument &karg)
Definition: device_gemm_xdl_streamk.hpp:125
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_streamk.hpp:208
float Run(const Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_streamk.hpp:127
Definition: device_gemm_xdl_streamk.hpp:70
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl_streamk.hpp:238
static constexpr auto I3
Definition: device_gemm_xdl_streamk.hpp:74
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, uint32_t NumSKBlocks=0xffffffff)
Definition: device_gemm_xdl_streamk.hpp:259
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_streamk.hpp:254
static auto MakeInvoker()
Definition: device_gemm_xdl_streamk.hpp:302
typename GridwiseGemm::Argument Argument
Definition: device_gemm_xdl_streamk.hpp:120
std::string GetTypeString() const override
Definition: device_gemm_xdl_streamk.hpp:355
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_streamk.hpp:349
static constexpr auto I2
Definition: device_gemm_xdl_streamk.hpp:73
static bool IsSupportedArgument(const Argument &karg)
Definition: device_gemm_xdl_streamk.hpp:244
static constexpr auto I0
Definition: device_gemm_xdl_streamk.hpp:71
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition: device_gemm_xdl_streamk.hpp:229
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition: device_gemm_xdl_streamk.hpp:215
static constexpr auto I1
Definition: device_gemm_xdl_streamk.hpp:72
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, index_t NumSKBlocks=0) override
Definition: device_gemm_xdl_streamk.hpp:305