Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <map>
#include <sstream>
#include <typeinfo>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          typename ReduceDataType = CDataType,
          typename ComputeTypeA   = CDataType,
          typename ComputeTypeB   = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
                                                           BLayout,
                                                           DsLayout,
                                                           CLayout,
                                                           ADataType,
                                                           BDataType,
                                                           DsDataType,
                                                           CDataType,
                                                           AElementwiseOperation,
                                                           BElementwiseOperation,
                                                           CElementwiseOperation>
{
    GET_NXDL_PER_WAVE_IMPL
    static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
    static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();

    static constexpr index_t NumDTensor = DsDataType::Size();

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // GridwiseGemm
    template <index_t NXdlPerWave_>
    using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3<
        ALayout,
        BLayout,
        CLayout,
        ADataType,
        BDataType,
        GemmAccDataType,
        CShuffleDataType,
        ReduceDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        PassThrough,
        GemmSpec,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave_,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        BlkGemmPipeSched,
        BlkGemmPipelineVer,
        ComputeTypeA,
        ComputeTypeB>;

    using GridwiseGemm64 = GridwiseGemmBase<NXdlPerWave64>;
    using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;

    struct Argument : public GridwiseGemm64::Argument
    {
        Argument(const ADataType* p_a_grid_,
                 const BDataType* p_b_grid_,
                 const std::array<const void*, NumDTensor> p_ds_,
                 CDataType* p_c_grid_,
                 index_t M_,
                 index_t N_,
                 index_t K_,
                 index_t StrideA_,
                 index_t StrideB_,
                 std::array<ck::index_t, NumDTensor> StrideDs_,
                 index_t StrideC_,
                 index_t k_batch_)
            : GridwiseGemm64::Argument(p_a_grid_,
                                       p_b_grid_,
                                       reinterpret_cast<ReduceDataType*>(p_c_grid_),
                                       M_,
                                       N_,
                                       K_,
                                       StrideA_,
                                       StrideB_,
                                       StrideC_,
                                       k_batch_,
                                       true),
              p_ds(p_ds_),
              StrideDs(StrideDs_)
        {
        }

        const std::array<const void*, NumDTensor> p_ds;
        std::array<ck::index_t, NumDTensor> StrideDs;
    };

    using ReduceAdd                = ck::reduce::Add;
    using OutElementwiseOperation  = CElementwiseOperation;

    // Vectorize the D read along N only when D is row-major; otherwise fall
    // back to scalar access.
    static constexpr auto DsVectorLengthSequence = generate_sequence_v2(
        [](auto i) {
            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
            if constexpr(is_same_v<DLayout, tensor_layout::gemm::RowMajor>)
                return Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{};
            else
                return Number<1>{};
        },
        Number<NumDTensor>{});

    using DeviceReduceInstance =
        DeviceReduceThreadWiseMultiD<ReduceDataType,  // InDataType,
                                     DsDataType,      // DsDatatype
                                     GemmAccDataType, // AccDataType,
                                     CDataType,       // OutDataType,
                                     3,               // Rank
                                     1,               // NumReduceDim
                                     ReduceAdd,
                                     PassThrough,
                                     OutElementwiseOperation,
                                     256,                                            // BlockSize_,
                                     CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_,
                                     1,                                              // KThreadSliceSize_,
                                     0,                                              // InSrcVectorDim_,
                                     CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_,
                                     CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_
                                     decltype(DsVectorLengthSequence)>;

    // Invoker
    struct Invoker : public BaseInvoker
    {
        float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            static constexpr index_t NumInDim  = 3;
            static constexpr index_t NumOutDim = 2;

            std::array<ck::index_t, NumInDim> in_lengths   = {arg.KBatch, arg.M, arg.N};
            std::array<ck::index_t, NumOutDim> out_lengths = {arg.M, arg.N};

            std::array<ck::index_t, NumInDim> in_strides;
            std::array<ck::index_t, NumOutDim> out_strides;
            if constexpr(is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                in_strides  = {arg.M * arg.N, arg.N, 1};
                out_strides = {arg.N, 1};
            }
            else
            {
                in_strides  = {arg.M * arg.N, 1, arg.M};
                out_strides = {1, arg.M};
            }

            std::array<int, 1> reduce_dims{0};

            std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths;
            std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides;

            static_for<0, NumDTensor, 1>{}([&](auto i) {
                DsLengths[i] = out_lengths;

                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
                if constexpr(is_same_v<DLayout, tensor_layout::gemm::RowMajor>)
                {
                    DsStrides[i] = {arg.StrideDs[i], 1};
                }
                else
                {
                    DsStrides[i] = {1, arg.StrideDs[i]};
                }
            });

            auto reduce = DeviceReduceInstance{};

            auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
                                                           in_strides,
                                                           DsLengths,
                                                           DsStrides,
                                                           out_lengths,
                                                           out_strides,
                                                           reduce_dims,
                                                           arg.p_workspace_,
                                                           arg.p_ds,
                                                           arg.p_c_grid,
                                                           PassThrough{},
                                                           OutElementwiseOperation{});

            auto invoker_ptr = reduce.MakeInvokerPointer();

            float ave_time = 0;

            if(reduce.IsSupportedArgument(argument_ptr.get()))
            {
                ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
            }
            else
            {
                throw std::runtime_error(
                    "The runtime parameters are not supported by the device instance, exiting!");
            }

            return ave_time;
        }
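
        // Worked example of the reduction view above (illustrative values, not
        // from the source): for row-major C with KBatch = 2, M = 3, N = 4 the
        // workspace is read as a [2, 3, 4] tensor with strides {12, 4, 1}, and
        // C(m, n) = out_elementwise(sum over kb of workspace[kb*12 + m*4 + n],
        // D0(m, n), ...), i.e. the split-K partials are folded while the Ds
        // and the C element-wise op are fused into the same pass.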

        template <typename GridwiseGemm>
        float RunImp(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{})
        {
            auto arg = *reinterpret_cast<const typename GridwiseGemm::Argument*>(&arg_);

            if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
                 is_same_v<ReduceDataType, CDataType>))
            {
                if(arg.p_workspace_ == nullptr)
                {
                    throw std::runtime_error("using reduce, but empty workspace!");
                }

                arg.p_c_grid = reinterpret_cast<ReduceDataType*>(arg.p_workspace_);
            }

            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);

            float ave_time = 0;

            index_t k_grain = arg.KBatch * KPerBlock;
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
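
            // Example (illustrative numbers): K = 1000, KBatch = 4, KPerBlock = 64
            // gives k_grain = 256 and K_split = ceil(1000 / 256) * 64 = 256. K_split
            // is the per-split K extent rounded up to a multiple of KPerBlock; it
            // drives the main-loop and tail-number selection below.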

            const auto Run = [&](const auto& kernel) {
                if(stream_config.flush_cache)
                {
                    ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
                        arg,
                        stream_config.rotating_count,
                        arg.M * arg.K * sizeof(ADataType),
                        arg.K * arg.N * sizeof(BDataType));
                    rotating_mem.Print();

                    auto run_flush_cache = [&]() {
                        // flush icache
                        ck::utility::flush_icache();
                        // rotating mem
                        rotating_mem.Next();
                    };

                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                        stream_config,
                        run_flush_cache,
                        kernel,
                        dim3(gdx, gdy, gdz),
                        dim3(BlockSize),
                        0,
                        arg);
                }
                else
                {
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
                }
            };

            constexpr index_t minimum_occupancy =
                BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;

            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                    true,
                                                                    InMemoryDataOperationEnum::Set,
                                                                    minimum_occupancy>;
                    Run(kernel);
                }
                // Tail number could be One to Seven
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy,
                                                        TailNumber::One>;
                        Run(kernel);
                    }
                    else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy,
                                                        TailNumber::Full>;
                        Run(kernel);
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Two>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Three>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Four>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Five>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Six>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Seven>;
                            Run(kernel);
                        }
                    }
                }
                // Tail number could be Odd or Even
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
                                                             true,
                                                             InMemoryDataOperationEnum::Set,
                                                             minimum_occupancy,
                                                             TailNumber::Odd>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
                                                             true,
                                                             InMemoryDataOperationEnum::Set,
                                                             minimum_occupancy,
                                                             TailNumber::Even>;
                        Run(kernel);
                    }
                }
                else
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy,
                                                        TailNumber::Odd>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy,
                                                        TailNumber::Even>;
                        Run(kernel);
                    }
                }
            }
            else
            {
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                    false,
                                                                    InMemoryDataOperationEnum::Set,
                                                                    minimum_occupancy>;
                    Run(kernel);
                }
            }

            if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
                 is_same_v<ReduceDataType, CDataType>))
            {
                // reduce c data
                ave_time += RunReduce(arg_, stream_config);
            }
            return ave_time;
        }

        INVOKER_RUN_IMPL

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
        {
            return false;
        }
        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
                                                       GemmSpec == GemmSpecialization::NKPadding ||
                                                       GemmSpec == GemmSpecialization::MNKPadding ||
                                                       GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }

        if(get_warp_size() == 64)
        {
            if constexpr(NXdlPerWave64 > 0)
            {
                return GridwiseGemm64::CheckValidity(arg);
            }
        }
        else
        {
            if constexpr(NXdlPerWave32 > 0)
            {
                return GridwiseGemm32::CheckValidity(
                    reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
            }
        }
        return false;
    }
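
    // Example of the K-alignment rule above (illustrative numbers): with
    // AK1 = BK1 = 8, K = 100 fails the K % 8 == 0 check and is accepted only
    // when GemmSpec requests K padding (KPadding, MKPadding, NKPadding or
    // MNKPadding).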

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             const std::array<const void*, NumDTensor> p_ds,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             std::array<ck::index_t, NumDTensor> StrideDs,
                             index_t StrideC,
                             index_t KBatch,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation)
    {
        return Argument{p_a, p_b, p_ds, p_c, M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      std::array<const void*, NumDTensor> p_ds,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      std::array<ck::index_t, NumDTensor> StrideDs,
                                                      index_t StrideC,
                                                      index_t KBatch,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          p_ds,
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideDs,
                                          StrideC,
                                          KBatch);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceGemmXdlUniversalReduce"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            << ">"
            << " BlkSize: "
            << BlockSize << ", "
            << "BlkTile: "
            << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
            << "WaveTile: "
            << MPerXDL<<"x"<<NPerXDL << ", "
            << "WaveMap: "
            << MXdlPerWave<<"x" << NXdlPerWave<<", "
            << "VmemReadVec: "
            << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
        // clang-format on

        return str.str();
    }

    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
    {
        auto arg = *dynamic_cast<const Argument*>(p_arg);

        if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
             is_same_v<ReduceDataType, CDataType>))
        {
            std::cout << "using workspace" << std::endl;
            return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType);
        }

        return 0;
    }
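
    // Workspace sizing example (illustrative numbers): M = N = 4096, KBatch = 4
    // and ReduceDataType = float give 4096 * 4096 * 4 * 4 B = 256 MiB, i.e. one
    // full MxN tile of partial results per K split.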
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
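
For context, the host-side call sequence for an instance of this device op follows the usual Composable Kernel pattern: build an Argument, attach a workspace when split-K reduction or D-tensor fusion is active, verify support, then launch through an Invoker. The sketch below is illustrative rather than part of the header: the helper name run_splitk_gemm is hypothetical, a fully specialized DeviceOp with zero D tensors and PassThrough element-wise operations is assumed, and SetWorkSpacePointer is taken from the BaseOperator interface in device_base.hpp.

#include <array>

#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp"

// Hypothetical host-side driver. Assumptions: DeviceOp is a fully specialized
// instance of the struct above with NumDTensor == 0 and PassThrough for all
// three element-wise operations; p_a/p_b/p_c/p_ws are device allocations, and
// p_ws holds at least GetWorkSpaceSize() bytes when KBatch > 1.
template <typename DeviceOp, typename AType, typename BType, typename CType>
float run_splitk_gemm(const AType* p_a, const BType* p_b, CType* p_c, void* p_ws,
                      ck::index_t M, ck::index_t N, ck::index_t K,
                      ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                      ck::index_t KBatch)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    auto device_op = DeviceOp{};

    // Empty arrays stand in for the (absent) D-tensor pointers and strides.
    auto argument = DeviceOp::MakeArgument(p_a, p_b, {}, p_c, M, N, K,
                                           StrideA, StrideB, {}, StrideC, KBatch,
                                           PassThrough{}, PassThrough{}, PassThrough{});

    // With KBatch > 1 the kernel writes partial results to the workspace and
    // RunReduce folds them into C, so the pointer must be attached up front.
    device_op.SetWorkSpacePointer(&argument, p_ws);

    if(!DeviceOp::IsSupportedArgument(argument))
    {
        return -1.f; // caller falls back to another instance
    }

    auto invoker = DeviceOp::MakeInvoker();

    // StreamConfig{nullptr, true}: default HIP stream, timed run; Run returns
    // the averaged kernel time in milliseconds.
    return invoker.Run(argument, StreamConfig{nullptr, true});
}

The same flow works through the type-erased interface (MakeArgumentPointer, MakeInvokerPointer, and the BaseArgument/BaseInvoker overloads) when the instance is selected at runtime from an instance factory rather than fixed at compile time.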