/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp Source File
device_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
20 
21 namespace ck {
22 namespace tensor_operation {
23 namespace device {
24 
124 template <typename ALayout,
125  typename BLayout,
126  typename CLayout,
127  typename ADataType,
128  typename BDataType,
129  typename CDataType,
130  typename AccDataType,
131  typename CShuffleDataType,
132  typename AElementwiseOperation,
133  typename BElementwiseOperation,
134  typename CElementwiseOperation,
135  GemmSpecialization GemmSpec,
136  index_t BlockSize,
137  index_t MPerBlock,
138  index_t NPerBlock,
139  index_t KPerBlock,
140  index_t AK1,
141  index_t BK1,
142  index_t MPerWmma,
143  index_t NPerWmma,
144  index_t MRepeat,
145  index_t NRepeat,
146  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
147  typename ABlockTransferThreadClusterArrangeOrder,
148  typename ABlockTransferSrcAccessOrder,
149  index_t ABlockTransferSrcVectorDim,
150  index_t ABlockTransferSrcScalarPerVector,
151  index_t ABlockTransferDstScalarPerVector_AK1,
152  bool ABlockLdsExtraM,
153  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
154  typename BBlockTransferThreadClusterArrangeOrder,
155  typename BBlockTransferSrcAccessOrder,
156  index_t BBlockTransferSrcVectorDim,
157  index_t BBlockTransferSrcScalarPerVector,
158  index_t BBlockTransferDstScalarPerVector_BK1,
159  bool BBlockLdsExtraN,
160  index_t CShuffleMRepeatPerShuffle,
161  index_t CShuffleNRepeatPerShuffle,
162  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
163  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
166  typename ComputeTypeA = CDataType,
167  typename ComputeTypeB = ComputeTypeA,
168  bool PermuteA = false,
169  bool PermuteB = false>
170 struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
171  BLayout,
172  CLayout,
173  ADataType,
174  BDataType,
175  CDataType,
176  AElementwiseOperation,
177  BElementwiseOperation,
178  CElementwiseOperation>
179 {
180  // GridwiseGemm
182  ALayout,
183  BLayout,
184  CLayout,
185  ADataType,
186  BDataType,
187  AccDataType,
188  CShuffleDataType,
189  CDataType,
190  AElementwiseOperation,
191  BElementwiseOperation,
192  CElementwiseOperation,
193  GemmSpec,
194  BlockSize,
195  MPerBlock,
196  NPerBlock,
197  KPerBlock,
198  AK1,
199  BK1,
200  MPerWmma,
201  NPerWmma,
202  MRepeat,
203  NRepeat,
204  ABlockTransferThreadClusterLengths_AK0_M_AK1,
205  ABlockTransferThreadClusterArrangeOrder,
206  ABlockTransferSrcAccessOrder,
207  ABlockTransferSrcVectorDim,
208  ABlockTransferSrcScalarPerVector,
209  ABlockTransferDstScalarPerVector_AK1,
210  false,
211  ABlockLdsExtraM,
212  BBlockTransferThreadClusterLengths_BK0_N_BK1,
213  BBlockTransferThreadClusterArrangeOrder,
214  BBlockTransferSrcAccessOrder,
215  BBlockTransferSrcVectorDim,
216  BBlockTransferSrcScalarPerVector,
217  BBlockTransferDstScalarPerVector_BK1,
218  false,
219  BBlockLdsExtraN,
220  CShuffleMRepeatPerShuffle,
221  CShuffleNRepeatPerShuffle,
222  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
223  CShuffleBlockTransferScalarPerVector_NPerBlock,
224  BlkGemmPipeSched,
225  BlkGemmPipelineVer,
226  ComputeTypeA,
227  ComputeTypeB,
228  PermuteA,
229  PermuteB>;
230 
232 
234  ADataType,
235  BDataType,
236  CDataType,
237  MPerBlock,
238  NPerBlock,
239  KPerBlock,
240  BlockSize,
241  AK1,
242  BK1,
243  GemmSpec,
244  BlkGemmPipeSched,
245  BlkGemmPipelineVer,
246  ComputeTypeA,
247  ComputeTypeB>;
248 
249  // Invoker
251 
252  static bool IsSupportedArgument(const Argument& arg)
253  {
255  }
256 
257  // polymorphic
258  bool IsSupportedArgument(const BaseArgument* p_arg) override
259  {
260  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
261  }
262 
263  index_t GetKPerBlock() override { return KPerBlock; }
264 
265  bool GetPermuteA() override { return PermuteA; }
266  bool GetPermuteB() override { return PermuteB; }
267 
268  static auto MakeArgument(const ADataType* p_a,
269  const BDataType* p_b,
270  CDataType* p_c,
271  index_t M,
272  index_t N,
273  index_t K,
274  index_t StrideA,
275  index_t StrideB,
276  index_t StrideC,
277  index_t KBatch,
278  AElementwiseOperation,
279  BElementwiseOperation,
280  CElementwiseOperation)
281  {
282  return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
283  }
284 
285  static auto MakeInvoker() { return Invoker{}; }
286 
287  // polymorphic
288  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
289  const void* p_b,
290  void* p_c,
291  index_t M,
292  index_t N,
293  index_t K,
294  index_t StrideA,
295  index_t StrideB,
296  index_t StrideC,
297  index_t KBatch,
298  AElementwiseOperation,
299  BElementwiseOperation,
300  CElementwiseOperation) override
301  {
302  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
303  static_cast<const BDataType*>(p_b),
304  static_cast<CDataType*>(p_c),
305  M,
306  N,
307  K,
308  StrideA,
309  StrideB,
310  StrideC,
311  KBatch);
312  }
313 
314  // polymorphic
315  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
316  {
317  return std::make_unique<Invoker>(Invoker{});
318  }
319 
320  // polymorphic
321  std::string GetTypeString() const override
322  {
323  auto str = std::stringstream();
324 
325  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
328 
329  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
335 
336  // clang-format off
337  str << "DeviceGemm_Wmma_CShuffleV3"
338  << "<"
339  << getGemmSpecializationString(GemmSpec) << ", "
340  << std::string(ALayout::name)[0]
341  << std::string(BLayout::name)[0]
342  << std::string(CLayout::name)[0]
343  << ">"
344  << " BlkSize: "
345  << BlockSize << ", "
346  << "BlkTile: "
347  << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
348  << "WaveTile: "
349  << MPerWmma << "x"<<NPerWmma << ", "
350  << "WaveMap: "
351  << MRepeat << "x" << NRepeat << ", "
352  << "VmemReadVec: "
353  << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
354  << "BlkGemmPipelineScheduler: "
355  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
356  << "BlkGemmPipelineVersion: "
357  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
358  << "BlkGemmPipelinePrefetchStages: "
359  << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
360  << "KPack: "
362  // clang-format on
363 
364  return str.str();
365  }
367 };
368 
369 } // namespace device
370 } // namespace tensor_operation
371 } // namespace ck
#define REGISTER_EXTRA_PRINTING_METHODS
Definition: device_base.hpp:46
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:267
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
int32_t index_t
Definition: ck.hpp:298
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:367
static constexpr index_t KPack
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:121
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:222
Definition: device_base.hpp:51
Helper structure responsible for kernel invocation.
Definition: device_gemm_wmma_cshuffle_v3_common.hpp:54
Definition: device_gemm_wmma_cshuffle_v3_common.hpp:40
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_wmma_cshuffle_v3_common.hpp:225
"Universal" GEMM operation with SplitK support.
Definition: device_gemm_wmma_cshuffle_v3.hpp:179
std::string GetTypeString() const override
Definition: device_gemm_wmma_cshuffle_v3.hpp:321
typename DeviceGemmCommon::Invoker Invoker
Definition: device_gemm_wmma_cshuffle_v3.hpp:250
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition: device_gemm_wmma_cshuffle_v3.hpp:229
static auto MakeInvoker()
Definition: device_gemm_wmma_cshuffle_v3.hpp:285
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_wmma_cshuffle_v3.hpp:258
bool GetPermuteA() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:265
typename GridwiseGemm::Argument Argument
Definition: device_gemm_wmma_cshuffle_v3.hpp:231
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:315
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_wmma_cshuffle_v3.hpp:288
bool GetPermuteB() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:266
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_wmma_cshuffle_v3.hpp:268
index_t GetKPerBlock() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:263
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_wmma_cshuffle_v3.hpp:252
Definition: device_gemm_v2.hpp:22