Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <map>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

/*
 * \brief Wrapper of GridwiseGemm::Run that realizes batched GEMM.
 *
 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B, and C
 * matrices for a given batch index. For example, ComputePtrOffsetOfStridedBatch() computes the
 * offsets of evenly strided batches, but this can easily be extended to other layouts. The
 * returned offset can be either \p index_t or \p long_index_t. If it returns \p long_index_t, we
 * are not subject to the 2 GB limitation.
 *
 * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes the ID of a workgroup and
 * returns the 2D index of the tile that it computes. \see
 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
 *
 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that two workgroups can compute
 * two tiles from different matrices. Keep in mind that these two matrices can share the same grid
 * descriptor (as in batched GEMM) or use their own grid descriptors (as in grouped GEMM). \link
 * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
 * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
 * computation of the pointer offset into \p ComputePtrOffsetOfStridedBatch.
 *
 * \note \p Block2CTileMap allows a customized mapping between a workgroup and the C tile it
 * computes. Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm
 * fusions) to realize batched GEMM and grouped GEMM (and the corresponding GEMM fusions).
 */
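/*
 * Illustrative sketch (not part of this header): the ComputePtrOffsetOfBatch concept only needs
 * the three Get{A,B,C}PtrOffset(g_idx) members called by the kernel below. A hypothetical functor
 * for batches placed at arbitrary positions, returning \p long_index_t so that offsets past the
 * 2 GB boundary remain valid, could look like:
 *
 *   struct ComputePtrOffsetOfBatchFromTable
 *   {
 *       __host__ __device__ long_index_t GetAPtrOffset(index_t g_idx) const
 *       {
 *           return p_a_offsets_[g_idx]; // per-batch offsets from a host-provided table
 *       }
 *       // GetBPtrOffset() / GetCPtrOffset() are analogous
 *       const long_index_t* p_a_offsets_;
 *   };
 */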
template <typename DeviceOp, typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
        const index_t num_blocks_per_batch =
            __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch);
        const index_t g_idx =
            __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

        const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
        const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
        const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));

        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        const auto a_grid_desc_k0_m_k1 =
            amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(
                karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
        const auto b_grid_desc_k0_n_k1 =
            amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(
                karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
        const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N(
            karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));

        GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid + a_batch_offset,
                                                      karg.p_b_grid + b_batch_offset,
                                                      karg.p_c_grid + c_batch_offset,
                                                      p_shared,
                                                      a_grid_desc_k0_m_k1,
                                                      b_grid_desc_k0_n_k1,
                                                      c_grid_desc_m_n);
    }
#else
    ignore = karg;
#endif
}
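// Grid decomposition sketch (illustrative numbers): the Invoker below multiplies the per-GEMM
// grid size by Batch, so with Batch = 4 and 32 workgroups per GEMM the launch has 128 workgroups.
// Workgroup 70 then computes g_idx = 70 / 32 = 2, i.e. it belongs to the third batch, and
// GridwiseGemm::Run maps it onto a C tile within that batch.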

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t K1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsAddExtraN,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          ck::index_t NumGemmKPrefetchStage = 1,
          LoopScheduler LoopSched           = make_default_loop_scheduler(),
          PipelineVersion PipelineVer       = PipelineVersion::v1>
struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
                                                       BLayout,
                                                       CLayout,
                                                       ADataType,
                                                       BDataType,
                                                       CDataType,
                                                       AElementwiseOperation,
                                                       BElementwiseOperation,
                                                       CElementwiseOperation>
{
    GET_NXDL_PER_WAVE_IMPL
    static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
    static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr auto K1Number = Number<K1>{};

    struct ComputePtrOffsetOfStridedBatch
    {
        ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                                       index_t BatchStrideB,
                                       index_t BatchStrideC)
            : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
        {
        }

        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideA_);
        }

        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB_);
        }

        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideC_);
        }

        private:
        index_t BatchStrideA_;
        index_t BatchStrideB_;
        index_t BatchStrideC_;
    };
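    // For instance, densely packed row-major batches would use BatchStrideA = M * K, placing
    // batch g at p_a_grid + g * static_cast<long_index_t>(M * K). Since the stride is widened
    // to long_index_t before the multiplication, the offset stays exact even when the full
    // tensor crosses the 2 GB element boundary.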

    // GridwiseGemm
    template <index_t NXdlPerWave_>
    using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
        BlockSize,
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CDataType,
        InMemoryDataOperationEnum::Set,
        ALayout,
        BLayout,
        CLayout,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXDL,
        NPerXDL,
        K1,
        MXdlPerWave,
        NXdlPerWave_,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun
        BBlockLdsAddExtraN,
        Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder
        CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector,
        NumGemmKPrefetchStage,
        LoopSched,
        PipelineVer>;

    using GridwiseGemm64 = GridwiseGemmBase<NXdlPerWave64>;
    using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;

    using Problem = typename GridwiseGemm64::Problem;

    // Argument
    struct Argument : public Problem, public BaseArgument
    {
        Argument(const ADataType* p_a_grid_,
                 const BDataType* p_b_grid_,
                 CDataType* p_c_grid_,
                 index_t M_,
                 index_t N_,
                 index_t K_,
                 index_t StrideA_,
                 index_t StrideB_,
                 index_t StrideC_,
                 index_t BatchStrideA,
                 index_t BatchStrideB,
                 index_t BatchStrideC,
                 index_t Batch_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_c_grid{p_c_grid_},
              Batch(Batch_),
              compute_ptr_offset_of_batch{BatchStrideA, BatchStrideB, BatchStrideC}
        {
        }

        const ADataType* p_a_grid;
        const BDataType* p_b_grid;
        CDataType* p_c_grid;
        index_t Batch;
        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceBatchedGemmXdl::Argument;

        template <typename GridwiseGemm>
        float RunImp(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                karg.Print();
            }

            typename GridwiseGemm::Problem arg(
                karg.M, karg.N, karg.K, karg.StrideA, karg.StrideB, karg.StrideC);
            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting");
            }

            auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
            gdx *= karg.Batch;

            float ave_time = 0;

            if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
            {
                const auto kernel =
                    kernel_batched_gemm_xdlops_v2r3<DeviceBatchedGemmXdl, GridwiseGemm, true>;

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
            }
            else
            {
                const auto kernel =
                    kernel_batched_gemm_xdlops_v2r3<DeviceBatchedGemmXdl, GridwiseGemm, false>;

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
            }

            return ave_time;
        }

        INVOKER_RUN_IMPL

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Problem& problem)
    {
        if(!ck::is_xdl_wmma_supported<ADataType, BDataType, MPerXDL, NPerXDL>())
        {
            return false;
        }
        // temporarily disabled on gfx11
        if(ck::is_gfx11_supported())
        {
            return false;
        }
        if(get_warp_size() == 64)
        {
            if constexpr(NXdlPerWave64 > 0)
            {
                return GridwiseGemm64::CheckValidity(problem);
            }
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
            {
                return GridwiseGemm32::CheckValidity(
                    reinterpret_cast<const typename GridwiseGemm32::Problem&>(problem));
            }
        }
        return false;
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             index_t BatchStrideA,
                             index_t BatchStrideB,
                             index_t BatchStrideC,
                             index_t Batch)
    {
        return Argument{p_a,
                        p_b,
                        p_c,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        BatchStrideA,
                        BatchStrideB,
                        BatchStrideC,
                        Batch};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      index_t BatchStrideA,
                                                      index_t BatchStrideB,
                                                      index_t BatchStrideC,
                                                      index_t Batch,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          BatchStrideA,
                                          BatchStrideB,
                                          BatchStrideC,
                                          Batch);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
                                                                       {PipelineVersion::v2, "v2"}};

        // clang-format off
        str << "DeviceBatchedGemmXdl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
            << NXdlPerWave << ", "
            << ">"
            << " NumGemmKPrefetchStage: "
            << NumGemmKPrefetchStage << ", "
            << "LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: "
            << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
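// Host-side usage sketch (illustrative only; the template arguments shown are placeholders and
// PassThrough is the no-op elementwise operator from CK's element_wise namespace):
//
//   using PassThrough = ck::tensor_operation::element_wise::PassThrough;
//   using DeviceOp    = ck::tensor_operation::device::DeviceBatchedGemmXdl<
//       ck::half_t, ck::half_t, ck::half_t, float, // A/B/C data types, accumulator type
//       Row, Col, Row,                             // layouts, e.g. from tensor_layout.hpp
//       PassThrough, PassThrough, PassThrough,
//       /* BlockSize, tile sizes, and block-transfer parameters elided */>;
//
//   DeviceOp op;
//   auto arg     = DeviceOp::MakeArgument(p_a, p_b, p_c, M, N, K,
//                                         StrideA, StrideB, StrideC,
//                                         M * K, K * N, M * N, // per-batch strides (packed)
//                                         Batch);
//   auto invoker = DeviceOp::MakeInvoker();
//   if(op.IsSupportedArgument(&arg))
//   {
//       invoker.Run(&arg, StreamConfig{});
//   }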