/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp Source File
device_gemm_xdl_streamk.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
24 template <typename ADataType,
25  typename BDataType,
26  typename CDataType,
27  typename AccDataType,
28  typename ALayout,
29  typename BLayout,
30  typename CLayout,
31  typename AElementwiseOperation,
32  typename BElementwiseOperation,
33  typename CElementwiseOperation,
34  ck::index_t BlockSize,
35  ck::index_t MPerBlock,
36  ck::index_t NPerBlock,
37  ck::index_t K0PerBlock,
38  ck::index_t K1,
39  ck::index_t MPerXDL,
40  ck::index_t NPerXDL,
41  ck::index_t MXdlPerWave,
42  ck::index_t NXdlPerWave,
43  typename ABlockTransferThreadClusterLengths_K0_M_K1,
44  typename ABlockTransferThreadClusterArrangeOrder,
45  typename ABlockTransferSrcAccessOrder,
46  ck::index_t ABlockTransferSrcVectorDim,
47  ck::index_t ABlockTransferSrcScalarPerVector,
48  ck::index_t ABlockTransferDstScalarPerVector_K1,
49  ck::index_t ABlockLdsAddExtraM,
50  typename BBlockTransferThreadClusterLengths_K0_N_K1,
51  typename BBlockTransferThreadClusterArrangeOrder,
52  typename BBlockTransferSrcAccessOrder,
53  ck::index_t BBlockTransferSrcVectorDim,
54  ck::index_t BBlockTransferSrcScalarPerVector,
55  ck::index_t BBlockTransferDstScalarPerVector_K1,
56  ck::index_t BBlockLdsAddExtraN,
57  index_t CShuffleMRepeatPerShuffle,
58  index_t CShuffleNRepeatPerShuffle,
59  typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60  index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
61 struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
62  BLayout,
63  CLayout,
64  ADataType,
65  BDataType,
66  CDataType,
67  AElementwiseOperation,
68  BElementwiseOperation,
69  CElementwiseOperation>
70 {
// Values produced by GetNXdlPerWave<true>() and GetNXdlPerWave<false>(). The methods of this
// struct test them with `if constexpr(... > 0)` before using GridwiseGemm64 (in the
// `get_warp_size() == 64` branch) or GridwiseGemm32 (in the `else` branch).
// NOTE(review): GetNXdlPerWave is not defined in this listing (the cross-reference index
// points at GET_NXDL_PER_WAVE_IMPL in device_base.hpp) — confirm what the bool selects.
72  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
73  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
74 
// Compile-time index constants Number<0> .. Number<3>; none of them is referenced in the
// part of the struct visible in this listing.
75  static constexpr auto I0 = Number<0>{};
76  static constexpr auto I1 = Number<1>{};
77  static constexpr auto I2 = Number<2>{};
78  static constexpr auto I3 = Number<3>{};
79 
// Template alias over a gridwise GEMM kernel type, parameterised on NXdlPerWave_.
// Apart from NXdlPerWave_, the computed `K0PerBlock * K1` and the two literal `false`
// flags, every visible argument is one of DeviceGemmXdlStreamK's own template parameters.
// NOTE(review): this listing skips source lines 81, 83 and 86, so the alias name, the
// underlying gridwise template and two of its arguments are not visible here — read the
// real header before editing the argument list.
80  template <index_t NXdlPerWave_>
82  BlockSize,
84  NPerBlock,
85  K0PerBlock * K1,
87  ADataType, // TODO: distinguish A/B datatype
88  AccDataType,
89  CDataType,
90  ALayout,
91  BLayout,
92  CLayout,
93  AElementwiseOperation,
94  BElementwiseOperation,
95  CElementwiseOperation,
96  MPerBlock,
97  NPerBlock,
98  K0PerBlock,
99  MPerXDL,
100  NPerXDL,
101  K1,
102  MXdlPerWave,
// The only argument that varies between instantiations of this alias.
103  NXdlPerWave_,
104  ABlockTransferThreadClusterLengths_K0_M_K1,
105  ABlockTransferThreadClusterArrangeOrder,
106  ABlockTransferSrcAccessOrder,
107  ABlockTransferSrcVectorDim,
108  ABlockTransferSrcScalarPerVector,
109  ABlockTransferDstScalarPerVector_K1,
110  false, // AThreadTransferSrcResetCoordinateAfterRun,
111  ABlockLdsAddExtraM,
112  BBlockTransferThreadClusterLengths_K0_N_K1,
113  BBlockTransferThreadClusterArrangeOrder,
114  BBlockTransferSrcAccessOrder,
115  BBlockTransferSrcVectorDim,
116  BBlockTransferSrcScalarPerVector,
117  BBlockTransferDstScalarPerVector_K1,
118  false, // BThreadTransferSrcResetCoordinateAfterRun,
119  BBlockLdsAddExtraN,
120  CShuffleMRepeatPerShuffle,
121  CShuffleNRepeatPerShuffle,
// Note the order: the scalar-per-vector value precedes the cluster lengths here, the
// reverse of the order in which the struct's template parameter list declares them.
122  CBlockTransferScalarPerVector_NWaveNPerXDL,
123  CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
126 
128 
129  // Invoker
130  struct Invoker : public BaseInvoker
131  {
// Prints a kernel argument by delegating to its own Print() member.
// Works for any type exposing a const-callable `Print()`.
template <typename KernelArgument>
void Print(const KernelArgument& karg)
{
    karg.Print();
}
137 
// Validates `karg`, launches kernel_gemm_xdlops_streamk<GridwiseGemm> and returns the value
// reported by the launch_and_time_kernel* helper.
//
// karg          - kernel argument; read for pointers, sizes, strides and block_mapping.
// stream_config - stream to launch on; log_level_ > 0 additionally prints `karg`.
// returns       - time from the launch helper, or 0 if neither `if constexpr` branch
//                 below is compiled in (ave_time keeps its initial value).
// throws        - std::runtime_error when GridwiseGemm::CheckValidity(karg) is false.
138  template <typename GridwiseGemm>
139  float RunImp(const typename GridwiseGemm::Argument& karg,
140  const StreamConfig& stream_config = StreamConfig{})
141  {
142  if(stream_config.log_level_ > 0)
143  {
144  Print(karg);
145  }
146  if(!GridwiseGemm::CheckValidity(karg))
147  {
// NOTE(review): the message names GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 although the
// kernel launched below is kernel_gemm_xdlops_streamk — looks copy-pasted; confirm.
148  throw std::runtime_error(
149  "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
150  "setting");
151  }
152 
// The launch grid comes from the block mapping carried inside the argument.
153  dim3 grid_dims = karg.block_mapping.get_grid_dims();
154 
155  float ave_time = 0;
156 
157  const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
158 
159  // TODO: remove clear buffer for streamk kernels
// NOTE(review): the right-hand side of this comparison (source line 161) is missing from
// the listing; the cross-reference index lists enumerators `Atomic` and `Reduction`.
160  if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
162  {
// Zero M * N * sizeof(CDataType) bytes at p_c_grid on the launch stream before the kernel.
// The hipError_t is only fed to hipGetErrorString and the resulting string is discarded,
// so a failed memset goes unreported.
// NOTE(review): consider hip_check_error here. Also, if karg.M / karg.N are index_t
// (int32_t per the index), M * N is multiplied before widening to size_t — confirm the
// member types and whether large problems can overflow.
163  hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
164  0,
165  karg.M * karg.N * sizeof(CDataType),
166  stream_config.stream_id_));
167  ave_time = launch_and_time_kernel(stream_config,
168  kernel,
169  grid_dims,
170  dim3(BlockSize),
171  0,
172  karg.p_a_grid,
173  karg.p_b_grid,
174  karg.p_c_grid,
175  karg.p_workspace_,
176  karg.M,
177  karg.N,
178  karg.K,
179  karg.StrideA,
180  karg.StrideB,
181  karg.StrideC,
182  karg.block_mapping);
183  }
// NOTE(review): the right-hand side of this comparison (source line 185) is also missing.
184  else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
186  {
// The semaphore area starts get_workspace_size_for_acc(sizeof(FloatAcc)) bytes into the
// workspace buffer (presumably the accumulator area occupies the front — confirm against
// the block mapping implementation).
187  char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_) +
188  karg.block_mapping.get_workspace_size_for_acc(
189  sizeof(typename GridwiseGemm::FloatAcc));
// Callback that zeroes the semaphore area on the launch stream; it is handed to
// launch_and_time_kernel_with_preprocess, which decides when to invoke it. As above, the
// hipMemsetAsync status is discarded.
190  auto preprocess = [&]() {
191  hipGetErrorString(
192  hipMemsetAsync(workspace_semaphore,
193  0,
194  karg.block_mapping.get_workspace_size_for_semaphore(),
195  stream_config.stream_id_));
196  };
197 
198  ave_time = launch_and_time_kernel_with_preprocess(stream_config,
199  preprocess,
200  kernel,
201  grid_dims,
202  dim3(BlockSize),
203  0,
204  karg.p_a_grid,
205  karg.p_b_grid,
206  karg.p_c_grid,
207  karg.p_workspace_,
208  karg.M,
209  karg.N,
210  karg.K,
211  karg.StrideA,
212  karg.StrideB,
213  karg.StrideC,
214  karg.block_mapping);
215  }
216 
217  return ave_time;
218  }
219 
221 
222  // polymorphic
223  float Run(const BaseArgument* p_arg,
224  const StreamConfig& stream_config = StreamConfig{}) override
225  {
226  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
227  }
228  };
229 
// Returns the workspace size for `pArg`: block_mapping.get_workspace_size(sizeof(FloatAcc))
// when the `if constexpr` condition of the kernel variant matching get_warp_size() holds,
// otherwise 0.
// NOTE(review): the right-hand sides of both `if constexpr` comparisons (source lines 236
// and 245) are missing from this listing.
// NOTE(review): `p_arg` is dereferenced without checking the dynamic_cast result, so a null
// or foreign BaseArgument is undefined behaviour in the branches that use it.
230  size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
231  {
232  const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
233  if(get_warp_size() == 64)
234  {
235  if constexpr(GridwiseGemm64::Block2CTileMap::ReductionStrategy ==
237  {
238  return p_arg->block_mapping.get_workspace_size(
239  sizeof(typename GridwiseGemm64::FloatAcc));
240  }
241  }
242  else
243  {
244  if constexpr(GridwiseGemm32::Block2CTileMap::ReductionStrategy ==
246  {
247  return p_arg->block_mapping.get_workspace_size(
248  sizeof(typename GridwiseGemm32::FloatAcc));
249  }
250  }
// Reached when the compiled-in strategy check above is false for the selected variant.
251  return 0;
252  }
253 
255  void* p_workspace,
256  const StreamConfig& = StreamConfig{}) const override
257  {
// Stores `p_workspace` in the argument's p_workspace_ member; the StreamConfig parameter is
// unnamed and unused.
// NOTE(review): the first line of the signature (source line 254) is missing from this
// listing; the cross-reference index gives it as
// `void SetWorkSpacePointer(BaseArgument* pArg, ...)`.
258  Argument* pArg_ = dynamic_cast<Argument*>(pArg);
259 
// NOTE(review): pArg_ is not checked for null before the write, so a null or foreign
// BaseArgument is undefined behaviour here.
260  pArg_->p_workspace_ = p_workspace;
261  }
262 
// Compile-time hook reporting whether this instantiation's template parameters are usable.
// Currently a placeholder that accepts every configuration.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool parameters_accepted = true;
    return parameters_accepted;
}
268 
// Reports whether `karg` can be run by this instance.
// Returns false when ck::is_xdl_wmma_supported<...>() is false, or when the NXdlPerWave
// value for the branch selected by get_warp_size() is not > 0; otherwise the result comes
// from the matching gridwise kernel's validity check.
269  static bool IsSupportedArgument(const Argument& karg)
270  {
271  if(!ck::is_xdl_wmma_supported<ADataType, BDataType, MPerXDL, NPerXDL>())
272  {
273  return false;
274  }
275  if(get_warp_size() == 64)
276  {
277  if constexpr(NXdlPerWave64 > 0)
278  {
279  return GridwiseGemm64::CheckValidity(karg);
280  }
281  }
282  else
283  {
284  if constexpr(NXdlPerWave32 > 0)
285  {
// NOTE(review): the start of this statement (source line 286) is missing from the listing;
// by symmetry with the branch above it is presumably `return GridwiseGemm32::CheckValidity(`
// — confirm. `karg` is reinterpreted as GridwiseGemm32::Argument, which relies on that type
// being layout-compatible with Argument (the index shows Argument aliasing
// GridwiseGemm64::Argument) — confirm.
287  reinterpret_cast<const typename GridwiseGemm32::Argument&>(karg));
288  }
289  }
// Falls through to here when the selected variant's `if constexpr` guard is false.
290  return false;
291  }
292 
293  // polymorphic
294  bool IsSupportedArgument(const BaseArgument* p_arg) override
295  {
296  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
297  }
298 
// Builds an Argument by value from typed pointers and problem sizes, adding two values
// queried from the HIP runtime: the multiprocessor count of the current device and the
// occupancy of the stream-k kernel at BlockSize threads per block.
// The three elementwise-operation parameters are unnamed and unused.
// NOTE(review): NumSKBlocks defaults to 0xffffffff here but to 0 in MakeArgumentPointer —
// confirm the two defaults are meant to differ.
299  static auto MakeArgument(const ADataType* p_a,
300  const BDataType* p_b,
301  CDataType* p_c,
302  index_t M,
303  index_t N,
304  index_t K,
305  index_t StrideA,
306  index_t StrideB,
307  index_t StrideC,
308  AElementwiseOperation,
309  BElementwiseOperation,
310  CElementwiseOperation,
311  uint32_t NumSKBlocks = 0xffffffff)
312  {
313  int num_cu;
314  hipError_t rtn;
// Immediately-invoked lambda: occupancy_ stays 0 when the `if constexpr` guard of the
// selected warp-size branch is false. `rtn` is captured by reference and reused below.
315  int occupancy = [&]() {
316  int occupancy_ = 0;
317  if(get_warp_size() == 64)
318  {
319  if constexpr(NXdlPerWave64 > 0)
320  {
321  const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm64>;
// NOTE(review): the last argument and closing parenthesis of this call (source line 326)
// are missing from the listing.
322  rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
323  &occupancy_,
324  kernel,
325  BlockSize,
327  hip_check_error(rtn);
328  }
329  }
330  else
331  {
332  if constexpr(NXdlPerWave32 > 0)
333  {
334  const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm32>;
// NOTE(review): source line 339 (end of this call) is missing from the listing as well.
335  rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
336  &occupancy_,
337  kernel,
338  BlockSize,
340  hip_check_error(rtn);
341  }
342  }
343  return occupancy_;
344  }();
345 
// Look up the current device and read its multiprocessor count.
346  hipDeviceProp_t dev_prop;
347  hipDevice_t dev;
348  rtn = hipGetDevice(&dev);
349  hip_check_error(rtn);
350  rtn = hipGetDeviceProperties(&dev_prop, dev);
351  hip_check_error(rtn);
352  num_cu = dev_prop.multiProcessorCount;
353 
354  return Argument{p_a,
355  p_b,
356  p_c,
357  M,
358  N,
359  K,
360  StrideA,
361  StrideB,
362  StrideC,
363  static_cast<uint32_t>(num_cu),
364  static_cast<uint32_t>(occupancy),
365  NumSKBlocks};
366  }
367 
368  static auto MakeInvoker() { return Invoker{}; }
369 
370  // polymorphic
// Heap-allocating counterpart of MakeArgument: takes untyped pointers, reinterprets them
// as ADataType / BDataType / CDataType, and returns the Argument behind a BaseArgument
// pointer. Occupancy and multiprocessor count are queried exactly as in MakeArgument.
// The three elementwise-operation parameters are unnamed and unused.
// NOTE(review): NumSKBlocks is a signed index_t defaulting to 0 here, whereas MakeArgument
// takes uint32_t defaulting to 0xffffffff; a negative value wraps in the static_cast at
// the end — confirm the intended contract.
371  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
372  const void* p_b,
373  void* p_c,
374  index_t M,
375  index_t N,
376  index_t K,
377  index_t StrideA,
378  index_t StrideB,
379  index_t StrideC,
380  AElementwiseOperation,
381  BElementwiseOperation,
382  CElementwiseOperation,
383  index_t NumSKBlocks = 0) override
384  {
385  int num_cu;
386  hipError_t rtn;
387 
// Immediately-invoked lambda: occupancy_ stays 0 when the `if constexpr` guard of the
// selected warp-size branch is false. `rtn` is captured by reference and reused below.
388  int occupancy = [&]() {
389  int occupancy_ = 0;
390  if(get_warp_size() == 64)
391  {
392  if constexpr(NXdlPerWave64 > 0)
393  {
394  const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm64>;
// NOTE(review): the last argument and closing parenthesis of this call (source line 399)
// are missing from the listing.
395  rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
396  &occupancy_,
397  kernel,
398  BlockSize,
400  hip_check_error(rtn);
401  }
402  }
403  else
404  {
405  if constexpr(NXdlPerWave32 > 0)
406  {
407  const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm32>;
// NOTE(review): source line 412 (end of this call) is missing from the listing as well.
408  rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
409  &occupancy_,
410  kernel,
411  BlockSize,
413  hip_check_error(rtn);
414  }
415  }
416  return occupancy_;
417  }();
418 
// Look up the current device and read its multiprocessor count.
419  hipDeviceProp_t dev_prop;
420  hipDevice_t dev;
421  rtn = hipGetDevice(&dev);
422  hip_check_error(rtn);
423  rtn = hipGetDeviceProperties(&dev_prop, dev);
424  hip_check_error(rtn);
425  num_cu = dev_prop.multiProcessorCount;
426 
427  return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
428  reinterpret_cast<const BDataType*>(p_b),
429  reinterpret_cast<CDataType*>(p_c),
430  M,
431  N,
432  K,
433  StrideA,
434  StrideB,
435  StrideC,
436  static_cast<uint32_t>(num_cu),
437  static_cast<uint32_t>(occupancy),
438  static_cast<uint32_t>(NumSKBlocks));
439  }
440 
441  // polymorphic
442  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
443  {
444  return std::make_unique<Invoker>(Invoker{});
445  }
446 
447  // polymorphic
// Returns a std::string describing this operation.
// NOTE(review): the statement(s) forming the body (source lines 450-451) are missing from
// this listing. The cross-reference index lists a static `GetTypeString()` in
// gridwise_gemm_xdlops_streamk.hpp, presumably what is forwarded to — confirm in the
// real header.
448  std::string GetTypeString() const override
449  {
452  }
453 };
454 
455 } // namespace device
456 } // namespace tensor_operation
457 } // namespace ck
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:91
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:268
@ Atomic
Definition: block_to_ctile_map.hpp:1012
@ Reduction
Definition: block_to_ctile_map.hpp:1013
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
int32_t index_t
Definition: ck.hpp:299
unsigned int uint32_t
Definition: stdint.h:126
Definition: stream_config.hpp:10
Definition: block_to_ctile_map.hpp:1022
Definition: gridwise_gemm_xdlops_streamk.hpp:140
Definition: gridwise_gemm_xdlops_streamk.hpp:115
__host__ static constexpr __device__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdlops_streamk.hpp:315
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_streamk.hpp:289
static std::string GetTypeString()
Definition: gridwise_gemm_xdlops_streamk.hpp:1163
FloatAcc_ FloatAcc
Definition: gridwise_gemm_xdlops_streamk.hpp:132
Definition: integral_constant.hpp:20
Definition: device_base.hpp:197
void * p_workspace_
Definition: device_base.hpp:204
Definition: device_base.hpp:208
Definition: device_gemm_streamk.hpp:25
Definition: device_gemm_xdl_streamk.hpp:131
void Print(const Argument_ &karg)
Definition: device_gemm_xdl_streamk.hpp:133
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_streamk.hpp:223
float RunImp(const typename GridwiseGemm::Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_streamk.hpp:139
Definition: device_gemm_xdl_streamk.hpp:70
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl_streamk.hpp:263
static constexpr auto I3
Definition: device_gemm_xdl_streamk.hpp:78
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, uint32_t NumSKBlocks=0xffffffff)
Definition: device_gemm_xdl_streamk.hpp:299
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_streamk.hpp:294
static auto MakeInvoker()
Definition: device_gemm_xdl_streamk.hpp:368
std::string GetTypeString() const override
Definition: device_gemm_xdl_streamk.hpp:448
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_streamk.hpp:442
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_gemm_xdl_streamk.hpp:72
typename GridwiseGemm64::Argument Argument
Definition: device_gemm_xdl_streamk.hpp:127
static constexpr auto I2
Definition: device_gemm_xdl_streamk.hpp:77
static bool IsSupportedArgument(const Argument &karg)
Definition: device_gemm_xdl_streamk.hpp:269
static constexpr auto I0
Definition: device_gemm_xdl_streamk.hpp:75
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition: device_gemm_xdl_streamk.hpp:254
static constexpr auto NXdlPerWave32
Definition: device_gemm_xdl_streamk.hpp:73
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition: device_gemm_xdl_streamk.hpp:230
static constexpr auto I1
Definition: device_gemm_xdl_streamk.hpp:76
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, index_t NumSKBlocks=0) override
Definition: device_gemm_xdl_streamk.hpp:371