/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp Source File
device_cgemm_4gemm_xdl_cshuffle.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
21 
22 namespace ck {
23 namespace tensor_operation {
24 namespace device {
25 
26 template <
27  typename ALayout,
28  typename BLayout,
29  typename CLayout,
30  typename ADataType,
31  typename BDataType,
32  typename CDataType,
33  typename GemmAccDataType,
34  typename CShuffleDataType,
35  typename AElementwiseOperation,
36  typename BElementwiseOperation,
37  typename CElementwiseOperation,
38  GemmSpecialization GemmSpec,
39  index_t NumGemmKPrefetchStage,
40  index_t BlockSize,
41  index_t MPerBlock,
42  index_t NPerBlock,
43  index_t KPerBlock,
44  index_t AK1,
45  index_t BK1,
46  index_t MPerXDL,
47  index_t NPerXDL,
48  index_t MXdlPerWave,
49  index_t NXdlPerWave,
50  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
51  typename ABlockTransferThreadClusterArrangeOrder,
52  typename ABlockTransferSrcAccessOrder,
53  index_t ABlockTransferSrcVectorDim,
54  index_t ABlockTransferSrcScalarPerVector,
55  index_t ABlockTransferDstScalarPerVector_AK1,
56  bool ABlockLdsExtraM,
57  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
58  typename BBlockTransferThreadClusterArrangeOrder,
59  typename BBlockTransferSrcAccessOrder,
60  index_t BBlockTransferSrcVectorDim,
61  index_t BBlockTransferSrcScalarPerVector,
62  index_t BBlockTransferDstScalarPerVector_BK1,
63  bool BBlockLdsExtraN,
64  index_t CShuffleMXdlPerWavePerShuffle,
65  index_t CShuffleNXdlPerWavePerShuffle,
66  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
67  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
70  is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
71  is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
72  is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
73  bool> = false>
75  : public DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
76 {
79  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
80  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
81 
82  static constexpr auto I0 = Number<0>{};
83  static constexpr auto I1 = Number<1>{};
84  static constexpr auto I2 = Number<2>{};
85 
86  static constexpr index_t MPerThread =
87  MPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
88  static constexpr index_t NPerThread =
89  NPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
90 
91  static constexpr auto AScalarPerVector = Number<4>{};
92  static constexpr auto BScalarPerVector = Number<4>{};
93  static constexpr auto CScalarPerVector = Number<4>{};
94 
95  template <typename Desc_M_N>
96  static auto PadDescriptor_M_N(Desc_M_N desc)
97  {
98  const auto M = desc.GetLength(I0);
99  const auto N = desc.GetLength(I1);
100  const auto pad_M = math::integer_divide_ceil(M, MPerThread) * MPerThread - M;
101  const auto pad_N = math::integer_divide_ceil(N, NPerThread) * NPerThread - N;
102 
103  const auto padded_desc = transform_tensor_descriptor(
104  desc,
108 
109  return padded_desc;
110  }
111 
112  static auto MakeDescriptor_M_N(const std::vector<index_t>& lengths,
113  const std::vector<index_t>& strides)
114  {
115  auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{});
116  auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<2>{});
117 
118  // nd desc - [s0, s1, s2, ...]
119  const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
120  return PadDescriptor_M_N(desc);
121  }
122 
123  // GridwiseGemm
124  template <index_t NXdlPerWave_>
126  ALayout,
127  BLayout,
128  CLayout,
129  ADataType,
130  BDataType,
131  GemmAccDataType,
132  CShuffleDataType,
133  CDataType,
134  AElementwiseOperation,
135  BElementwiseOperation,
136  CElementwiseOperation,
137  GemmSpec,
139  NumGemmKPrefetchStage,
140  BlockSize,
141  MPerBlock,
142  NPerBlock,
143  KPerBlock,
144  AK1,
145  BK1,
146  MPerXDL,
147  NPerXDL,
148  MXdlPerWave,
149  NXdlPerWave_,
150  ABlockTransferThreadClusterLengths_AK0_M_AK1,
151  ABlockTransferThreadClusterArrangeOrder,
152  ABlockTransferSrcAccessOrder,
153  ABlockTransferSrcVectorDim,
154  ABlockTransferSrcScalarPerVector,
155  ABlockTransferDstScalarPerVector_AK1,
156  false,
157  ABlockLdsExtraM,
158  BBlockTransferThreadClusterLengths_BK0_N_BK1,
159  BBlockTransferThreadClusterArrangeOrder,
160  BBlockTransferSrcAccessOrder,
161  BBlockTransferSrcVectorDim,
162  BBlockTransferSrcScalarPerVector,
163  BBlockTransferDstScalarPerVector_BK1,
164  false,
165  BBlockLdsExtraN,
166  CShuffleMXdlPerWavePerShuffle,
167  CShuffleNXdlPerWavePerShuffle,
168  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
169  CShuffleBlockTransferScalarPerVector_NPerBlock,
170  LoopSched>;
173 
174  using CGridDesc_M_N = decltype(MakeDescriptor_M_N({1, 1}, {1, 1}));
175 
176  // Argument
178  {
180 
181  Argument(const ADataType* p_a_grid_real_,
182  const ADataType* p_a_grid_imag_,
183  const BDataType* p_b_grid_real_,
184  const BDataType* p_b_grid_imag_,
185  CDataType* p_c_grid_real_,
186  CDataType* p_c_grid_imag_,
187  CDataType* p_workspace,
188  index_t M_,
189  index_t N_,
190  index_t K_,
191  index_t StrideA_,
192  index_t StrideB_,
193  index_t StrideC_)
194  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
195  p_a_grid_real{p_a_grid_real_},
196  p_a_grid_imag{p_a_grid_imag_},
197  p_b_grid_real{p_b_grid_real_},
198  p_b_grid_imag{p_b_grid_imag_},
199  p_c_grid_real{p_c_grid_real_},
200  p_c_grid_imag{p_c_grid_imag_},
201  p_aux_grid{p_workspace}
202  {
204  {
205  c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {StrideC_, I1});
206  }
208  {
209  c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {I1, StrideC_});
210  }
211 
212  p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_);
213  }
214 
215  // private:
216  const ADataType* p_a_grid_real;
217  const ADataType* p_a_grid_imag;
218  const BDataType* p_b_grid_real;
219  const BDataType* p_b_grid_imag;
220  CDataType* p_c_grid_real;
221  CDataType* p_c_grid_imag;
222  CDataType* p_aux_grid;
223  CDataType* p_aux_2_grid;
225  };
226 
227  // Invoker
228  struct Invoker : public BaseInvoker
229  {
230  template <typename GridwiseGemm>
231  float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
232  {
233  if(stream_config.log_level_ > 0)
234  {
235  arg.Print();
236  }
237 
238  typename GridwiseGemm::Problem problem(
239  arg.M, arg.N, arg.K, arg.StrideA, arg.StrideB, arg.StrideC);
240  if(!GridwiseGemm::CheckValidity(problem))
241  {
242  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
243  }
244 
245  index_t gdx, gdy, gdz;
246  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
247 
248  const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
249 
250  float ave_time = 0;
251 
254 
256 
261  Block2TileMap,
262  Add,
263  BlockSize,
264  MPerBlock,
265  NPerBlock,
266  MPerThread,
267  NPerThread,
271  I1,
272  I1>;
273 
274  using GridwiseBinSubtract =
279  Block2TileMap,
280  Subtract,
281  BlockSize,
282  MPerBlock,
283  NPerBlock,
284  MPerThread,
285  NPerThread,
289  I1,
290  I1>;
291 
292  const index_t M = arg.c_grid_desc_m_n.GetLength(I0);
293  const index_t N = arg.c_grid_desc_m_n.GetLength(I1);
294  const auto block_2_tile_map = Block2TileMap(M, N);
295 
296  const auto add_kernel = kernel_elementwise<GridwiseBinAdd,
301  Block2TileMap,
302  Add>;
303 
304  const auto subtract_kernel =
305  kernel_elementwise<GridwiseBinSubtract,
310  Block2TileMap,
311  Subtract>;
312 
313  if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
314  {
315  const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
316  ADataType,
317  BDataType,
318  CDataType,
319  true>;
320 
321  ave_time += launch_and_time_kernel(stream_config,
322  kernel,
323  dim3(gdx, gdy, gdz),
324  dim3(BlockSize),
325  0,
326  arg.p_a_grid_real,
327  arg.p_b_grid_real,
328  arg.p_aux_grid,
329  problem);
330 
331  ave_time += launch_and_time_kernel(stream_config,
332  kernel,
333  dim3(gdx, gdy, gdz),
334  dim3(BlockSize),
335  0,
336  arg.p_a_grid_imag,
337  arg.p_b_grid_imag,
338  arg.p_aux_2_grid,
339  problem);
340 
341  // c_real = aux - aux_2
342  ave_time += launch_and_time_kernel(
343  stream_config,
344  subtract_kernel,
345  dim3(gdx, gdy, gdz),
346  dim3(BlockSize),
347  0,
350  make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
351  const_cast<const CDataType*>(arg.p_aux_2_grid)),
353  block_2_tile_map,
354  Subtract{});
355 
356  ave_time += launch_and_time_kernel(stream_config,
357  kernel,
358  dim3(gdx, gdy, gdz),
359  dim3(BlockSize),
360  0,
361  arg.p_a_grid_real,
362  arg.p_b_grid_imag,
363  arg.p_aux_grid,
364  problem);
365 
366  ave_time += launch_and_time_kernel(stream_config,
367  kernel,
368  dim3(gdx, gdy, gdz),
369  dim3(BlockSize),
370  0,
371  arg.p_a_grid_imag,
372  arg.p_b_grid_real,
373  arg.p_aux_2_grid,
374  problem);
375 
376  // c_imag = aux + aux_2
377  ave_time += launch_and_time_kernel(
378  stream_config,
379  add_kernel,
380  dim3(gdx, gdy, gdz),
381  dim3(BlockSize),
382  0,
385  make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
386  const_cast<const CDataType*>(arg.p_aux_2_grid)),
388  block_2_tile_map,
389  Add{});
390  }
391  else
392  {
393  const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
394  ADataType,
395  BDataType,
396  CDataType,
397  false>;
398 
399  ave_time += launch_and_time_kernel(stream_config,
400  kernel,
401  dim3(gdx, gdy, gdz),
402  dim3(BlockSize),
403  0,
404  arg.p_a_grid_real,
405  arg.p_b_grid_real,
406  arg.p_aux_grid,
407  problem);
408 
409  ave_time += launch_and_time_kernel(stream_config,
410  kernel,
411  dim3(gdx, gdy, gdz),
412  dim3(BlockSize),
413  0,
414  arg.p_a_grid_imag,
415  arg.p_b_grid_imag,
416  arg.p_aux_2_grid,
417  problem);
418 
419  // c_real = aux - aux_2
420  ave_time += launch_and_time_kernel(
421  stream_config,
422  subtract_kernel,
423  dim3(gdx, gdy, gdz),
424  dim3(BlockSize),
425  0,
428  make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
429  const_cast<const CDataType*>(arg.p_aux_2_grid)),
431  block_2_tile_map,
432  Subtract{});
433 
434  ave_time += launch_and_time_kernel(stream_config,
435  kernel,
436  dim3(gdx, gdy, gdz),
437  dim3(BlockSize),
438  0,
439  arg.p_a_grid_real,
440  arg.p_b_grid_imag,
441  arg.p_aux_grid,
442  problem);
443 
444  ave_time += launch_and_time_kernel(stream_config,
445  kernel,
446  dim3(gdx, gdy, gdz),
447  dim3(BlockSize),
448  0,
449  arg.p_a_grid_imag,
450  arg.p_b_grid_real,
451  arg.p_aux_2_grid,
452  problem);
453 
454  // c_imag = aux + aux_2
455  ave_time += launch_and_time_kernel(
456  stream_config,
457  add_kernel,
458  dim3(gdx, gdy, gdz),
459  dim3(BlockSize),
460  0,
463  make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
464  const_cast<const CDataType*>(arg.p_aux_2_grid)),
466  block_2_tile_map,
467  Add{});
468  }
469 
470  return ave_time;
471  }
472 
474 
475  // polymorphic
476  float Run(const BaseArgument* p_arg,
477  const StreamConfig& stream_config = StreamConfig{}) override
478  {
479  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
480  }
481  };
482 
// Compile-time sanity hook for the template parameters of this operation.
// Currently a placeholder that accepts every parameter combination.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
488 
489  static bool IsSupportedArgument(const Argument& arg)
490  {
491  if(!ck::is_xdl_wmma_supported<ADataType, BDataType, MPerXDL, NPerXDL>())
492  {
493  return false;
494  }
495  if(get_warp_size() == 64)
496  {
497  if constexpr(NXdlPerWave64 > 0)
498  {
499  return GridwiseGemm64::CheckValidity(arg);
500  }
501  }
502  else
503  {
504  if constexpr(NXdlPerWave32 > 0)
505  {
506  typename GridwiseGemm32::Problem problem(
507  arg.M, arg.N, arg.K, arg.StrideA, arg.StrideB, arg.StrideC);
508  return GridwiseGemm32::CheckValidity(problem);
509  }
510  }
511  return false;
512  }
513 
514  // polymorphic
515  bool IsSupportedArgument(const BaseArgument* p_arg) override
516  {
517  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
518  }
519 
520  static auto MakeArgument(const ADataType* p_a_real,
521  const ADataType* p_a_imag,
522  const BDataType* p_b_real,
523  const BDataType* p_b_imag,
524  CDataType* p_c_real,
525  CDataType* p_c_imag,
526  CDataType* p_workspace,
527  index_t M,
528  index_t N,
529  index_t K,
530  index_t StrideA,
531  index_t StrideB,
532  index_t StrideC,
533  AElementwiseOperation,
534  BElementwiseOperation,
535  CElementwiseOperation)
536  {
537  return Argument{p_a_real,
538  p_a_imag,
539  p_b_real,
540  p_b_imag,
541  p_c_real,
542  p_c_imag,
543  p_workspace,
544  M,
545  N,
546  K,
547  StrideA,
548  StrideB,
549  StrideC};
550  }
551 
552  static auto MakeInvoker() { return Invoker{}; }
553 
554  // polymorphic
555  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
556  const void* p_a_imag,
557  const void* p_b_real,
558  const void* p_b_imag,
559  void* p_c_real,
560  void* p_c_imag,
561  void* p_workspace,
562  index_t M,
563  index_t N,
564  index_t K,
565  index_t StrideA,
566  index_t StrideB,
567  index_t StrideC,
568  AElementwiseOperation,
569  BElementwiseOperation,
570  CElementwiseOperation,
571  index_t /* KBatch */ = 1) override
572  {
573  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a_real),
574  static_cast<const ADataType*>(p_a_imag),
575  static_cast<const BDataType*>(p_b_real),
576  static_cast<const BDataType*>(p_b_imag),
577  static_cast<CDataType*>(p_c_real),
578  static_cast<CDataType*>(p_c_imag),
579  static_cast<CDataType*>(p_workspace),
580  M,
581  N,
582  K,
583  StrideA,
584  StrideB,
585  StrideC);
586  }
587 
588  // polymorphic
589  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
590  {
591  return std::make_unique<Invoker>(Invoker{});
592  }
593 
594  // polymorphic
595  std::string GetTypeString() const override
596  {
597  auto str = std::stringstream();
598 
599  // clang-format off
600  str << "DeviceCGemm_4Gemm_Xdl_CShuffle"
601  << "<"
602  << BlockSize << ", "
603  << MPerBlock << ", "
604  << NPerBlock << ", "
605  << KPerBlock << ", "
606  << AK1 << ", "
607  << BK1
608  << ">";
609  // clang-format on
610 
611  return str.str();
612  }
613 
614  static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC)
615  {
616  const auto c_grid_desc_m_n =
619  N,
621  StrideC);
622 
623  return c_grid_desc_m_n.GetElementSpaceSize();
624  }
625 
626  std::size_t GetWorkspaceSize(index_t M,
627  index_t N,
628  [[maybe_unused]] index_t K,
629  [[maybe_unused]] index_t StrideA,
630  [[maybe_unused]] index_t StrideB,
631  index_t StrideC) const override
632  {
633  return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
634  }
635 
636  std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
637  {
638  const auto* parg = dynamic_cast<const Argument*>(base_arg);
639 
640  if(!parg)
641  {
642  std::ostringstream err;
643  err << "Provided argument pointer is not of an Argument class!" << " In " << __FILE__
644  << ":" << __LINE__ << ", in function: " << __func__;
645  throw std::runtime_error(err.str());
646  }
647 
648  return GetWorkspaceSize(
649  parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
650  }
651 };
652 
653 } // namespace device
654 } // namespace tensor_operation
655 } // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition: device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
__global__ void kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:25
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:299
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition: gridwise_elementwise_2d.hpp:29
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_elementwise_2d.hpp:278
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:421
index_t N
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:454
index_t StrideA
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:456
index_t StrideB
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:457
index_t K
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:455
__host__ void Print() const
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:444
index_t M
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:453
index_t StrideC
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:458
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:121
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:571
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:149
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:144
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:368
Definition: sequence.hpp:43
Definition: tuple.hpp:117
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: device_base.hpp:197
Definition: device_base.hpp:208
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:178
CGridDesc_M_N c_grid_desc_m_n
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:224
CDataType * p_c_grid_real
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:220
const BDataType * p_b_grid_imag
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:219
typename GridwiseGemm64::Problem Problem
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:179
const ADataType * p_a_grid_imag
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:217
const ADataType * p_a_grid_real
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:216
CDataType * p_aux_grid
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:222
CDataType * p_aux_2_grid
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:223
CDataType * p_c_grid_imag
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:221
Argument(const ADataType *p_a_grid_real_, const ADataType *p_a_grid_imag_, const BDataType *p_b_grid_real_, const BDataType *p_b_grid_imag_, CDataType *p_c_grid_real_, CDataType *p_c_grid_imag_, CDataType *p_workspace, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:181
const BDataType * p_b_grid_real
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:218
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:229
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:231
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:476
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:76
static constexpr auto I2
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:84
static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC)
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:614
static auto MakeDescriptor_M_N(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:112
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:515
static constexpr bool IsValidCompilationParameter()
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:483
static auto MakeInvoker()
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:552
static constexpr auto NXdlPerWave32
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:80
std::size_t GetWorkspaceSize(index_t M, index_t N, [[maybe_unused]] index_t K, [[maybe_unused]] index_t StrideA, [[maybe_unused]] index_t StrideB, index_t StrideC) const override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:626
static constexpr index_t MPerThread
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:86
decltype(MakeDescriptor_M_N({1, 1}, {1, 1})) CGridDesc_M_N
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:174
static constexpr auto I1
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:83
static constexpr auto I0
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:82
static constexpr auto CScalarPerVector
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:93
std::string GetTypeString() const override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:595
static auto PadDescriptor_M_N(Desc_M_N desc)
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:96
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:79
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a_real, const void *p_a_imag, const void *p_b_real, const void *p_b_imag, void *p_c_real, void *p_c_imag, void *p_workspace, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, index_t=1) override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:555
static constexpr auto BScalarPerVector
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:92
static auto MakeArgument(const ADataType *p_a_real, const ADataType *p_a_imag, const BDataType *p_b_real, const BDataType *p_b_imag, CDataType *p_c_real, CDataType *p_c_imag, CDataType *p_workspace, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:520
static bool IsSupportedArgument(const Argument &arg)
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:489
std::size_t GetWorkSpaceSize(const BaseArgument *base_arg) const override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:636
static constexpr index_t NPerThread
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:88
static constexpr auto AScalarPerVector
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:91
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:589
Definition: device_cgemm.hpp:15
Definition: binary_element_wise_operation.hpp:14
Definition: binary_element_wise_operation.hpp:237