Composable Kernel — device_gemm_wmma_cshuffle_v3r1.hpp (Doxygen source listing)
Path: include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <sstream>
7 #include <type_traits>
8 #include <typeinfo>
9 #include <memory>
10 #include <array>
11 #include <stdexcept>
12 
14 #include "ck/ck.hpp"
24 
28 
29 namespace ck {
30 namespace tensor_operation {
31 namespace device {
32 
33 template <typename ALayout,
34  typename BLayout,
35  typename DsLayout,
36  typename CLayout,
37  typename ADataType,
38  typename BDataType,
39  typename DsDataType,
40  typename CDataType,
41  typename GemmAccDataType,
42  typename CShuffleDataType,
43  typename AElementwiseOperation,
44  typename BElementwiseOperation,
45  typename CElementwiseOperation,
46  GemmSpecialization GemmSpec,
47  index_t BlockSize,
48  index_t MPerBlock,
49  index_t NPerBlock,
50  index_t KPerBlock,
51  index_t AK1,
52  index_t BK1,
53  index_t MPerWmma,
54  index_t NPerWmma,
55  index_t MRepeat,
56  index_t NRepeat,
57  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
58  typename ABlockTransferThreadClusterArrangeOrder,
59  typename ABlockTransferSrcAccessOrder,
60  index_t ABlockTransferSrcVectorDim,
61  index_t ABlockTransferSrcScalarPerVector,
62  index_t ABlockTransferDstScalarPerVector_AK1,
63  bool ABlockLdsExtraM,
64  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
65  typename BBlockTransferThreadClusterArrangeOrder,
66  typename BBlockTransferSrcAccessOrder,
67  index_t BBlockTransferSrcVectorDim,
68  index_t BBlockTransferSrcScalarPerVector,
69  index_t BBlockTransferDstScalarPerVector_BK1,
70  bool BBlockLdsExtraN,
71  index_t CShuffleMRepeatPerShuffle,
72  index_t CShuffleNRepeatPerShuffle,
73  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
74  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
77  typename ReduceDataType = CDataType,
78  typename ComputeTypeA = CDataType,
79  typename ComputeTypeB = ComputeTypeA>
81  BLayout,
82  DsLayout,
83  CLayout,
84  ADataType,
85  BDataType,
86  DsDataType,
87  CDataType,
88  AElementwiseOperation,
89  BElementwiseOperation,
90  CElementwiseOperation>
91 {
92  static constexpr index_t NumDTensor = DsDataType::Size();
93 
95 
97  ALayout,
98  BLayout,
99  Tuple<>,
100  CLayout,
103  GemmAccDataType,
104  ReduceDataType,
105  Tuple<>,
106  ReduceDataType,
107  AElementwiseOperation,
108  BElementwiseOperation,
109  PassThrough,
110  GemmSpec,
111  BlockSize,
112  MPerBlock,
113  NPerBlock,
114  KPerBlock,
115  AK1,
116  BK1,
117  MPerWmma,
118  NPerWmma,
119  MRepeat,
120  NRepeat,
121  ABlockTransferThreadClusterLengths_AK0_M_AK1,
122  ABlockTransferThreadClusterArrangeOrder,
123  ABlockTransferSrcAccessOrder,
124  ABlockTransferSrcVectorDim,
125  ABlockTransferSrcScalarPerVector,
126  ABlockTransferDstScalarPerVector_AK1,
127  false,
128  ABlockLdsExtraM,
129  BBlockTransferThreadClusterLengths_BK0_N_BK1,
130  BBlockTransferThreadClusterArrangeOrder,
131  BBlockTransferSrcAccessOrder,
132  BBlockTransferSrcVectorDim,
133  BBlockTransferSrcScalarPerVector,
134  BBlockTransferDstScalarPerVector_BK1,
135  false,
136  BBlockLdsExtraN,
137  CShuffleMRepeatPerShuffle,
138  CShuffleNRepeatPerShuffle,
139  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
141  BlkGemmPipeSched,
142  BlkGemmPipelineVer,
143  ComputeTypeA,
144  ComputeTypeB,
145  false,
146  false>;
147 
149  {
150  Argument(std::array<const void*, 1> p_a_grid_,
151  std::array<const void*, 1> p_b_grid_,
152  const ::std::array<const void*, NumDTensor> p_ds_,
153  CDataType* p_c_grid_,
154  index_t M_,
155  index_t N_,
156  index_t K_,
157  std::array<index_t, 1> StrideA_,
158  std::array<index_t, 1> StrideB_,
159  const ::std::array<index_t, NumDTensor> stride_ds_,
160  index_t StrideC_,
161  index_t KBatch_,
162  AElementwiseOperation a_element_op_,
163  BElementwiseOperation b_element_op_,
164  CElementwiseOperation c_element_op_)
165  : GridwiseGemm::Argument(p_a_grid_,
166  p_b_grid_,
167  ::std::array<const void*, 0>{},
168  reinterpret_cast<ReduceDataType*>(p_c_grid_),
169  M_,
170  N_,
171  K_,
172  StrideA_,
173  StrideB_,
174  std::array<index_t, 0>{},
175  StrideC_,
176  KBatch_,
177  a_element_op_,
178  b_element_op_,
179  PassThrough{},
180  true),
181  p_c_grid(p_c_grid_),
182  c_element_op(c_element_op_),
183  p_ds(p_ds_),
184  StrideDs(stride_ds_)
185  {
186  }
187 
188  CDataType* p_c_grid;
189  CElementwiseOperation c_element_op;
190  const ::std::array<const void*, NumDTensor> p_ds;
191  ::std::array<index_t, NumDTensor> StrideDs;
192  };
193 
195  using OutElementwiseOperation = CElementwiseOperation;
196 
198  [](auto i) {
199  using DLayout = ::std::__remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
202  else
203  return Number<1>{};
204  },
206 
208  ReduceDataType, // InDataType
209  DsDataType, // DsDatatype
210  GemmAccDataType, // AccDataType
211  CDataType, // OutDataType
212  3, // Rank
213  1, // NumReduceDim
214  ReduceAdd,
215  PassThrough,
217  256, // BlockSize_
218  CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_
219  1, // KThreadSliceSize_
220  0, // InSrcVectorDim_
221  CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_
222  CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_
223  decltype(DsVectorLengthSequence)>;
224 
225  struct Invoker : public BaseInvoker
226  {
227  float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
228  {
229  static constexpr index_t NumInDim = 3;
230  static constexpr index_t NumOutDim = 2;
231 
232  ::std::array<index_t, NumInDim> in_lengths = {arg.KBatch, arg.M, arg.N};
233  ::std::array<index_t, NumOutDim> out_lengths = {arg.M, arg.N};
234 
235  ::std::array<index_t, NumInDim> in_strides;
236  ::std::array<index_t, NumOutDim> out_strides;
238  {
239  in_strides = {arg.M * arg.N, arg.N, 1};
240  out_strides = {arg.N, 1};
241  }
242  else
243  {
244  in_strides = {arg.M * arg.N, 1, arg.M};
245  out_strides = {1, arg.M};
246  }
247 
248  ::std::array<int, 1> reduce_dims{0};
249 
250  ::std::array<::std::array<index_t, NumOutDim>, NumDTensor> DsLengths;
251  ::std::array<::std::array<index_t, NumOutDim>, NumDTensor> DsStrides;
252 
253  static_for<0, NumDTensor, 1>{}([&](auto i) {
254  DsLengths[i] = out_lengths;
255 
256  using DLayout = ::std::__remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
258  {
259  DsStrides[i] = {arg.StrideDs[i], 1};
260  }
261  else
262  {
263  DsStrides[i] = {1, arg.StrideDs[i]};
264  }
265  });
266 
267  auto reduce = DeviceReduceInstance{};
268 
269  auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
270  in_strides,
271  DsLengths,
272  DsStrides,
273  out_lengths,
274  out_strides,
275  reduce_dims,
276  arg.p_workspace_,
277  arg.p_ds,
278  arg.p_c_grid,
279  PassThrough{},
281 
282  auto invoker_ptr = reduce.MakeInvokerPointer();
283 
284  float ave_time = 0;
285 
286  if(reduce.IsSupportedArgument(argument_ptr.get()))
287  {
288  ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
289  }
290  else
291  {
292  throw ::std::runtime_error(
293  "The runtime parameters are not supported by the device instance.");
294  }
295 
296  return ave_time;
297  }
298 
299  float Run(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{})
300  {
301  auto arg = *dynamic_cast<const typename GridwiseGemm::Argument*>(&arg_);
302 
303  // workspace required when doing two-kernel reduce or Ds present
304  const bool need_workspace = !(!(arg.IsReduceAdd() || NumDTensor > 0) &&
306  if(need_workspace)
307  {
308  if(arg.p_workspace_ == nullptr)
309  {
310  throw ::std::runtime_error("using reduce, but empty workspace!");
311  }
312  arg.p_e_grid = reinterpret_cast<ReduceDataType*>(arg.p_workspace_);
313  }
314 
315  if(stream_config.log_level_ > 0)
316  {
317  arg.Print();
318  }
319 
321  {
322  throw ::std::runtime_error("wrong! GridwiseGemm has invalid setting");
323  }
324 
325  index_t gdx, gdy, gdz;
326  ::std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
327 
328  float ave_time = 0;
329 
330  index_t k_grain = arg.KBatch * KPerBlock;
331  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
332 
333  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
334 
335  constexpr index_t minimum_occupancy =
336  BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
337 
338  if(has_main_k_block_loop)
339  {
340  const auto kernel =
342  true,
344  minimum_occupancy>;
345  ave_time = launch_and_time_kernel(
346  stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg);
347  }
348  else
349  {
350  const auto kernel =
352  false,
354  minimum_occupancy>;
355  ave_time = launch_and_time_kernel(
356  stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg);
357  }
358 
359  if(need_workspace)
360  {
361  ave_time += RunReduce(arg_, stream_config);
362  }
363 
364  return ave_time;
365  }
366 
367  // polymorphic
368  float Run(const BaseArgument* p_arg,
369  const StreamConfig& stream_config = StreamConfig{}) override
370  {
371  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
372  }
373  };
374 
375  static constexpr bool IsValidCompilationParameter()
376  {
377  // TODO: properly implement this
378  return true;
379  }
380 
381  static bool IsSupportedArgument(const Argument& arg)
382  {
383  if(!ck::is_wmma_supported())
384  {
385  return false;
386  }
387 
388  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
389  GemmSpec == GemmSpecialization::NKPadding ||
390  GemmSpec == GemmSpecialization::MNKPadding ||
391  GemmSpec == GemmSpecialization::KPadding))
392  {
393  return false;
394  }
395 
396  return GridwiseGemm::CheckValidity(
397  *dynamic_cast<const typename GridwiseGemm::Argument*>(&arg));
398  }
399 
400  bool IsSupportedArgument(const BaseArgument* p_arg) override
401  {
402  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
403  }
404 
405  static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
406  {
407  return GridwiseGemm::CalculateGridSize(M, N, KBatch);
408  }
409 
410  static constexpr index_t GetBlockSize() { return BlockSize; }
411 
413  {
414  return GridwiseGemm::GetSharedMemoryNumberOfByte();
415  }
416 
417  static auto MakeArgument(const ADataType* p_a,
418  const BDataType* p_b,
419  const ::std::array<const void*, NumDTensor> p_ds,
420  CDataType* p_c,
421  index_t M,
422  index_t N,
423  index_t K,
424  index_t StrideA,
425  index_t StrideB,
426  const ::std::array<index_t, NumDTensor> stride_ds,
427  index_t StrideC,
428  index_t KBatch,
429  AElementwiseOperation a_element_op,
430  BElementwiseOperation b_element_op,
431  CElementwiseOperation c_element_op)
432  {
433  return Argument{std::array<const void*, 1>{p_a},
434  std::array<const void*, 1>{p_b},
435  p_ds,
436  p_c,
437  M,
438  N,
439  K,
440  std::array<index_t, 1>{StrideA},
441  std::array<index_t, 1>{StrideB},
442  stride_ds,
443  StrideC,
444  KBatch,
445  a_element_op,
446  b_element_op,
447  c_element_op};
448  }
449 
450  static auto MakeInvoker() { return Invoker{}; }
451 
452  // polymorphic
453  ::std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
454  {
455  return ::std::make_unique<Invoker>(Invoker{});
456  }
457 
458  // Polymorphic interfaces
459  ::std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
460  const void* p_b,
461  ::std::array<const void*, NumDTensor> p_ds,
462  void* p_c,
463  index_t M,
464  index_t N,
465  index_t K,
466  index_t StrideA,
467  index_t StrideB,
468  ::std::array<index_t, NumDTensor> DsStrides,
469  index_t StrideC,
470  index_t KSplit,
471  AElementwiseOperation a_element_op,
472  BElementwiseOperation b_element_op,
473  CElementwiseOperation c_element_op) override
474  {
475  return ::std::make_unique<Argument>(std::array<const void*, 1>{p_a},
476  std::array<const void*, 1>{p_b},
477  p_ds,
478  static_cast<CDataType*>(p_c),
479  M,
480  N,
481  K,
482  std::array<index_t, 1>{StrideA},
483  std::array<index_t, 1>{StrideB},
484  DsStrides,
485  StrideC,
486  KSplit,
487  a_element_op,
488  b_element_op,
489  c_element_op);
490  }
491 
492  ::std::string GetTypeString() const override
493  {
494  auto str = ::std::stringstream();
495 
496  auto BlkGemmPipelineSchedulerToString = [](BlockGemmPipelineScheduler s) {
497  switch(s)
498  {
499  case BlockGemmPipelineScheduler::Intrawave: return ::std::string("Intrawave");
500  case BlockGemmPipelineScheduler::Interwave: return ::std::string("Interwave");
501  }
502  return ::std::string("?");
503  };
504 
505  auto BlkGemmPipelineVersionToString = [](BlockGemmPipelineVersion v) {
506  switch(v)
507  {
508  case BlockGemmPipelineVersion::v1: return ::std::string("v1");
509  case BlockGemmPipelineVersion::v2: return ::std::string("v2");
510  case BlockGemmPipelineVersion::v3: return ::std::string("v3");
511  case BlockGemmPipelineVersion::v4: return ::std::string("v4");
512  case BlockGemmPipelineVersion::v5: return ::std::string("v5");
513  }
514  return ::std::string("v?");
515  };
516 
517  // clang-format off
518  str << "DeviceGemmWmmaUniversalReduce"
519  << "<"
520  << getGemmSpecializationString(GemmSpec) << ", "
521  << ::std::string(ALayout::name)[0]
522  << ::std::string(BLayout::name)[0]
523  << ::std::string(CLayout::name)[0]
524  << ">"
525  << " BlkSize: "
526  << BlockSize << ", "
527  << "BlkTile: "
528  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
529  << "WmmaTile: "
530  << MPerWmma<<"x"<<NPerWmma << ", "
531  << "WmmaRepeat: "
532  << MRepeat<<"x" << NRepeat<<", "
533  << "VmemReadVec: "
534  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
535  << "BlkGemmPipelineScheduler: "
536  << BlkGemmPipelineSchedulerToString(BlkGemmPipeSched) << ", "
537  << "BlkGemmPipelineVersion: "
538  << BlkGemmPipelineVersionToString(BlkGemmPipelineVer) << ", "
539  << "BlkGemmPipelinePrefetchStages: "
540  << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
541  // clang-format on
542 
543  return str.str();
544  }
545 
546  size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
547  {
548  auto arg = *dynamic_cast<const Argument*>(p_arg);
549 
550  // Need workspace if using split-K or have D tensors
551  if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && is_same<CDataType, ReduceDataType>::value))
552  {
553  return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType);
554  }
555 
556  return 0;
557  }
558 };
559 
560 } // namespace device
561 } // namespace tensor_operation
562 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:268
bool is_wmma_supported()
Definition: device_prop.hpp:127
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto generate_sequence_v2(F &&f, Number< N >)
Definition: sequence_helper.hpp:25
__global__ void kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:35
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
int32_t index_t
Definition: ck.hpp:299
Definition: stream_config.hpp:10
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:408
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:460
EDataType * p_e_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:473
__host__ void Print() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:362
index_t M
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:388
index_t K
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:390
index_t N
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:389
index_t KBatch
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:395
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:170
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:941
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:1154
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:231
Definition: multi_index_transform.hpp:13
Definition: sequence.hpp:43
Definition: tuple.hpp:186
Definition: tuple.hpp:117
Definition: integral_constant.hpp:20
static constexpr bool value
Definition: integral_constant.hpp:21
Definition: type.hpp:177
Definition: reduction_operator.hpp:37
Definition: device_base.hpp:197
void * p_workspace_
Definition: device_base.hpp:204
Definition: device_base.hpp:208
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:149
Argument(std::array< const void *, 1 > p_a_grid_, std::array< const void *, 1 > p_b_grid_, const ::std::array< const void *, NumDTensor > p_ds_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, 1 > StrideA_, std::array< index_t, 1 > StrideB_, const ::std::array< index_t, NumDTensor > stride_ds_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:150
CDataType * p_c_grid
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:188
const ::std::array< const void *, NumDTensor > p_ds
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:190
CElementwiseOperation c_element_op
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:189
::std::array< index_t, NumDTensor > StrideDs
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:191
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:226
float RunReduce(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:227
float Run(const Argument &arg_, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:299
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:368
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:91
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:94
DeviceReduceThreadWiseMultiD< ReduceDataType, DsDataType, GemmAccDataType, CDataType, 3, 1, ReduceAdd, PassThrough, OutElementwiseOperation, 256, CShuffleBlockTransferScalarPerVector_NPerBlock, 1, 0, CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, decltype(DsVectorLengthSequence)> DeviceReduceInstance
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:223
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:375
static size_t GetSharedMemoryNumberOfByte()
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:412
static constexpr index_t NumDTensor
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:92
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const ::std::array< const void *, NumDTensor > p_ds, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, const ::std::array< index_t, NumDTensor > stride_ds, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:417
static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:405
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:381
static constexpr auto DsVectorLengthSequence
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:197
::std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:453
static auto MakeInvoker()
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:450
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, GemmAccDataType, ReduceDataType, Tuple<>, ReduceDataType, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, false, false > GridwiseGemm
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:146
ck::reduce::Add ReduceAdd
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:194
CElementwiseOperation OutElementwiseOperation
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:195
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:400
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:546
static constexpr index_t GetBlockSize()
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:410
::std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, ::std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, ::std::array< index_t, NumDTensor > DsStrides, index_t StrideC, index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:459
::std::string GetTypeString() const override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:492
Definition: device_gemm_v2.hpp:57
Definition: device_reduce_threadwise_multi_d.hpp:47
Definition: unary_element_wise_operation.hpp:334