Composable Kernel: device_gemm_xdl_cshuffle_v3r1.hpp source listing

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <map>
#include <sstream>
#include <typeinfo>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          typename ReduceDataType = CDataType,
          typename ComputeTypeA   = CDataType,
          typename ComputeTypeB   = ComputeTypeA>
struct DeviceGemmXdlUniversalReduce : public DeviceGemmV2R1<ALayout,
                                                            BLayout,
                                                            DsLayout,
                                                            CLayout,
                                                            ADataType,
                                                            BDataType,
                                                            DsDataType,
                                                            CDataType,
                                                            AElementwiseOperation,
                                                            BElementwiseOperation,
                                                            CElementwiseOperation>
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // GridwiseGemm: note that ReduceDataType replaces CDataType and PassThrough replaces the
    // user's C elementwise operation, because the GEMM stage writes partial results; the real
    // C op is applied later by the reduction epilogue.
    using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
        ALayout,
        BLayout,
        CLayout,
        ADataType,
        BDataType,
        GemmAccDataType,
        CShuffleDataType,
        ReduceDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        PassThrough,
        GemmSpec,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        BlkGemmPipeSched,
        BlkGemmPipelineVer,
        ComputeTypeA,
        ComputeTypeB>;

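    // Device-side argument: extends GridwiseGemm::Argument with the D-tensor pointers and
    // strides. The C pointer is reinterpreted as ReduceDataType since, when split-K or
    // D fusion is active, the GEMM output is redirected to a ReduceDataType workspace
    // (see Invoker::Run below).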
    struct Argument : public GridwiseGemm::Argument
    {
        Argument(const ADataType* p_a_grid_,
                 const BDataType* p_b_grid_,
                 const std::array<const void*, NumDTensor> p_ds_,
                 CDataType* p_c_grid_,
                 index_t M_,
                 index_t N_,
                 index_t K_,
                 index_t StrideA_,
                 index_t StrideB_,
                 std::array<ck::index_t, NumDTensor> StrideDs_,
                 index_t StrideC_,
                 index_t k_batch_)
            : GridwiseGemm::Argument(p_a_grid_,
                                     p_b_grid_,
                                     reinterpret_cast<ReduceDataType*>(p_c_grid_),
                                     M_,
                                     N_,
                                     K_,
                                     StrideA_,
                                     StrideB_,
                                     StrideC_,
                                     k_batch_,
                                     true),
              p_ds(p_ds_),
              StrideDs(StrideDs_)
        {
        }

        const std::array<const void*, NumDTensor> p_ds;
        std::array<ck::index_t, NumDTensor> StrideDs;
    };

    using ReduceAdd               = ck::reduce::Add;
    using OutElementwiseOperation = CElementwiseOperation;

    // Per-D vector length for the reduction: only row-major D tensors can be read vectorized
    // along N; anything else falls back to scalar access.
    static constexpr auto DsVectorLengthSequence = generate_sequence_v2(
        [](auto i) {
            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
            if constexpr(is_same_v<DLayout, tensor_layout::gemm::RowMajor>)
                return Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{};
            else
                return Number<1>{};
        },
        Number<NumDTensor>{});

    using DeviceReduceInstance = DeviceReduceThreadWiseMultiD<
        ReduceDataType,  // InDataType
        DsDataType,      // DsDatatype
        GemmAccDataType, // AccDataType
        CDataType,       // OutDataType
        3,               // Rank
        1,               // NumReduceDim
        ReduceAdd,
        PassThrough,
        OutElementwiseOperation,
        256,                                            // BlockSize_
        CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_
        1,                                              // KThreadSliceSize_
        0,                                              // InSrcVectorDim_
        CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_
        CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_
        decltype(DsVectorLengthSequence)>;
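
    // Reduction epilogue instance. The GEMM leaves KBatch contiguous MxN slabs of
    // ReduceDataType in the workspace; this threadwise multi-D reduce sums over the KBatch
    // dimension, fuses the D tensors, and applies OutElementwiseOperation to produce the
    // final CDataType output.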

    // Invoker
    struct Invoker : public BaseInvoker
    {
        float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            static constexpr index_t NumInDim  = 3;
            static constexpr index_t NumOutDim = 2;

            std::array<ck::index_t, NumInDim> in_lengths   = {arg.KBatch, arg.M, arg.N};
            std::array<ck::index_t, NumOutDim> out_lengths = {arg.M, arg.N};

            std::array<ck::index_t, NumInDim> in_strides;
            std::array<ck::index_t, NumOutDim> out_strides;
            if constexpr(is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                in_strides  = {arg.M * arg.N, arg.N, 1};
                out_strides = {arg.N, 1};
            }
            else
            {
                in_strides  = {arg.M * arg.N, 1, arg.M};
                out_strides = {1, arg.M};
            }

            std::array<int, 1> reduce_dims{0};

            std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths;
            std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides;

            static_for<0, NumDTensor, 1>{}([&](auto i) {
                DsLengths[i] = out_lengths;

                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
                if constexpr(is_same_v<DLayout, tensor_layout::gemm::RowMajor>)
                {
                    DsStrides[i] = {arg.StrideDs[i], 1};
                }
                else
                {
                    DsStrides[i] = {1, arg.StrideDs[i]};
                }
            });

            auto reduce = DeviceReduceInstance{};

            auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
                                                           in_strides,
                                                           DsLengths,
                                                           DsStrides,
                                                           out_lengths,
                                                           out_strides,
                                                           reduce_dims,
                                                           arg.p_workspace_,
                                                           arg.p_ds,
                                                           arg.p_c_grid,
                                                           PassThrough{},
                                                           OutElementwiseOperation{});

            auto invoker_ptr = reduce.MakeInvokerPointer();

            float ave_time = 0;

            if(reduce.IsSupportedArgument(argument_ptr.get()))
            {
                ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
            }
            else
            {
                throw std::runtime_error(
                    "The runtime parameters seem not to be supported by the device instance, "
                    "exiting!");
            }

            return ave_time;
        }

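        // Launch the GEMM, picking the kernel instantiation from the pipeline version, the
        // presence of a main K loop, and the K-loop tail number; afterwards run the reduction
        // epilogue if split-K accumulation or D-tensor fusion requires it.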
        float Run(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{})
        {
            auto arg = *dynamic_cast<const typename GridwiseGemm::Argument*>(&arg_);

            // When split-K accumulation or D-tensor fusion is needed, the GEMM must write its
            // partial results to the workspace instead of directly to C.
            if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
                 std::is_same<ReduceDataType, CDataType>::value))
            {
                if(arg.p_workspace_ == nullptr)
                {
                    throw std::runtime_error("using reduce, but empty workspace!");
                }

                arg.p_c_grid = reinterpret_cast<ReduceDataType*>(arg.p_workspace_);
            }

            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);

            float ave_time = 0;

            index_t k_grain = arg.KBatch * KPerBlock;
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

            const auto Run = [&](const auto& kernel) {
                if(stream_config.flush_cache)
                {
                    ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
                        arg,
                        stream_config.rotating_count,
                        arg.M * arg.K * sizeof(ADataType),
                        arg.K * arg.N * sizeof(BDataType));
                    rotating_mem.Print();

                    auto run_flush_cache = [&]() {
                        // flush icache
                        ck::utility::flush_icache();
                        // rotating mem
                        rotating_mem.Next();
                    };

                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                        stream_config,
                        run_flush_cache,
                        kernel,
                        dim3(gdx, gdy, gdz),
                        dim3(BlockSize),
                        0,
                        arg);
                }
                else
                {
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
                }
            };

            // Intrawave scheduling can run at occupancy 1; interwave pipelines rely on at
            // least two concurrent waves to overlap memory and math.
            constexpr index_t minimum_occupancy =
                BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;

            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                    true,
                                                                    InMemoryDataOperationEnum::Set,
                                                                    minimum_occupancy>;
                    Run(kernel);
                }
                // Tail number could be One to Seven
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy,
                                                        TailNumber::One>;
                        Run(kernel);
                    }
                    else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy,
                                                        TailNumber::Full>;
                        Run(kernel);
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Two>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Three>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Four>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Five>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Six>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Seven>;
                            Run(kernel);
                        }
                    }
                }
                // Tail number could be Odd or Even
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
                                                             true,
                                                             InMemoryDataOperationEnum::Set,
                                                             minimum_occupancy,
                                                             TailNumber::Odd>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
                                                             true,
                                                             InMemoryDataOperationEnum::Set,
                                                             minimum_occupancy,
                                                             TailNumber::Even>;
                        Run(kernel);
                    }
                }
                else
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy,
                                                        TailNumber::Odd>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy,
                                                        TailNumber::Even>;
                        Run(kernel);
                    }
                }
            }
            else
            {
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                    false,
                                                                    InMemoryDataOperationEnum::Set,
                                                                    minimum_occupancy>;
                    Run(kernel);
                }
            }

            if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
                 std::is_same<ReduceDataType, CDataType>::value))
            {
                // reduce c data
                ave_time += RunReduce(arg_, stream_config);
            }
            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        // K must be divisible by both AK1 and BK1 unless a K-padding specialization is enabled.
        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
                                                       GemmSpec == GemmSpecialization::NKPadding ||
                                                       GemmSpec == GemmSpecialization::MNKPadding ||
                                                       GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             const std::array<const void*, NumDTensor> p_ds,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             std::array<ck::index_t, NumDTensor> StrideDs,
                             index_t StrideC,
                             index_t KBatch,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation)
    {
        return Argument{p_a, p_b, p_ds, p_c, M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      std::array<const void*, NumDTensor> p_ds,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      std::array<ck::index_t, NumDTensor> StrideDs,
                                                      index_t StrideC,
                                                      index_t KBatch,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          p_ds,
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideDs,
                                          StrideC,
                                          KBatch);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceGemmXdlUniversalReduce"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            << ">"
            << " BlkSize: "
            << BlockSize << ", "
            << "BlkTile: "
            << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
            << "WaveTile: "
            << MPerXDL << "x" << NPerXDL << ", "
            << "WaveMap: "
            << MXdlPerWave << "x" << NXdlPerWave << ", "
            << "VmemReadVec: "
            << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
        // clang-format on

        return str.str();
    }

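    // Workspace sizing: with split-K or D fusion active, the GEMM needs one MxN slab of
    // ReduceDataType partials per K batch, hence M * N * KBatch elements.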
    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
    {
        auto arg = *dynamic_cast<const Argument*>(p_arg);

        if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
             std::is_same<ReduceDataType, CDataType>::value))
        {
            std::cout << "using workspace" << std::endl;
            return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType);
        }

        return 0;
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
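
A minimal host-side usage sketch for the device op above, assuming fp16 row-major A, column-major B, one fp16 row-major D tensor, fp32 accumulation, and tile parameters in the style of CK's GEMM instances. The pointers, problem sizes, strides, KBatch, and tile values are hypothetical placeholders, and any real configuration must pass IsSupportedArgument on the target GPU; DeviceMem is the device-buffer helper from CK's example utilities.

// Usage sketch; all concrete values below are illustrative, not a shipped CK instance.
#include "ck/library/utility/device_memory.hpp" // DeviceMem (CK example utilities)

using F16         = ck::half_t;
using F32         = float;
using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using DeviceOp = ck::tensor_operation::device::DeviceGemmXdlUniversalReduce<
    Row, Col, ck::Tuple<Row>, Row,           // A, B, Ds, C layouts
    F16, F16, ck::Tuple<F16>, F16, F32, F32, // A, B, Ds, C, Acc, CShuffle types
    PassThrough, PassThrough, PassThrough,   // A, B, C elementwise ops
    ck::tensor_operation::device::GemmSpecialization::Default,
    256,          // BlockSize
    128, 128, 64, // M/N/KPerBlock
    8, 8,         // AK1, BK1
    32, 32,       // M/NPerXDL
    2, 2,         // M/NXdlPerWave
    ck::Sequence<8, 32, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, false,
    ck::Sequence<8, 32, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, false,
    1, 1,                         // CShuffle M/N XdlPerWavePerShuffle
    ck::Sequence<1, 32, 1, 8>, 8>; // CShuffle cluster lengths, scalar-per-vector

void run_splitk_gemm(const F16* p_a, const F16* p_b, const F16* p_d, F16* p_c,
                     ck::index_t M, ck::index_t N, ck::index_t K,
                     ck::index_t StrideA, ck::index_t StrideB,
                     ck::index_t StrideD, ck::index_t StrideC)
{
    auto op      = DeviceOp{};
    auto invoker = op.MakeInvoker();
    auto arg     = op.MakeArgument(p_a, p_b, {p_d}, p_c, M, N, K, StrideA, StrideB,
                                   {StrideD}, StrideC, /*KBatch=*/4,
                                   PassThrough{}, PassThrough{}, PassThrough{});

    // KBatch > 1 routes partial sums through a workspace, so one must be attached first.
    DeviceMem workspace(op.GetWorkSpaceSize(&arg));
    op.SetWorkSpacePointer(&arg, workspace.GetDeviceBuffer());

    if(op.IsSupportedArgument(arg))
    {
        invoker.Run(arg, StreamConfig{});
    }
}

Staging the partials in a workspace and summing them with a separate reduce kernel, rather than atomically accumulating into C, keeps the split-K result deterministic and gives the epilogue a natural place to fuse the D tensors and the C elementwise operation.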