/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Source File
device_gemm_xdl_cshuffle_v3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
123 template <typename ALayout,
124  typename BLayout,
125  typename CLayout,
126  typename ADataType,
127  typename BDataType,
128  typename CDataType,
129  typename GemmAccDataType,
130  typename CShuffleDataType,
131  typename AElementwiseOperation,
132  typename BElementwiseOperation,
133  typename CElementwiseOperation,
134  GemmSpecialization GemmSpec,
135  index_t BlockSize,
136  index_t MPerBlock,
137  index_t NPerBlock,
138  index_t KPerBlock,
139  index_t AK1,
140  index_t BK1,
141  index_t MPerXDL,
142  index_t NPerXDL,
143  index_t MXdlPerWave,
144  index_t NXdlPerWave,
145  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
146  typename ABlockTransferThreadClusterArrangeOrder,
147  typename ABlockTransferSrcAccessOrder,
148  index_t ABlockTransferSrcVectorDim,
149  index_t ABlockTransferSrcScalarPerVector,
150  index_t ABlockTransferDstScalarPerVector_AK1,
151  bool ABlockLdsExtraM,
152  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
153  typename BBlockTransferThreadClusterArrangeOrder,
154  typename BBlockTransferSrcAccessOrder,
155  index_t BBlockTransferSrcVectorDim,
156  index_t BBlockTransferSrcScalarPerVector,
157  index_t BBlockTransferDstScalarPerVector_BK1,
158  bool BBlockLdsExtraN,
159  index_t CShuffleMXdlPerWavePerShuffle,
160  index_t CShuffleNXdlPerWavePerShuffle,
161  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
162  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
165  typename ComputeTypeA = CDataType,
166  typename ComputeTypeB = ComputeTypeA,
167  bool PermuteA = false,
168  bool PermuteB = false>
169 struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
170  BLayout,
171  CLayout,
172  ADataType,
173  BDataType,
174  CDataType,
175  AElementwiseOperation,
176  BElementwiseOperation,
177  CElementwiseOperation>
178 {
179  // GridwiseGemm
181  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
182  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
183 
184  template <index_t NXdlPerWave_>
186  ALayout,
187  BLayout,
188  CLayout,
189  ADataType,
190  BDataType,
191  GemmAccDataType,
192  CShuffleDataType,
193  CDataType,
194  AElementwiseOperation,
195  BElementwiseOperation,
196  CElementwiseOperation,
197  GemmSpec,
198  BlockSize,
199  MPerBlock,
200  NPerBlock,
201  KPerBlock,
202  AK1,
203  BK1,
204  MPerXDL,
205  NPerXDL,
206  MXdlPerWave,
207  NXdlPerWave_,
208  ABlockTransferThreadClusterLengths_AK0_M_AK1,
209  ABlockTransferThreadClusterArrangeOrder,
210  ABlockTransferSrcAccessOrder,
211  ABlockTransferSrcVectorDim,
212  ABlockTransferSrcScalarPerVector,
213  ABlockTransferDstScalarPerVector_AK1,
214  false,
215  ABlockLdsExtraM,
216  BBlockTransferThreadClusterLengths_BK0_N_BK1,
217  BBlockTransferThreadClusterArrangeOrder,
218  BBlockTransferSrcAccessOrder,
219  BBlockTransferSrcVectorDim,
220  BBlockTransferSrcScalarPerVector,
221  BBlockTransferDstScalarPerVector_BK1,
222  false,
223  BBlockLdsExtraN,
224  CShuffleMXdlPerWavePerShuffle,
225  CShuffleNXdlPerWavePerShuffle,
226  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
227  CShuffleBlockTransferScalarPerVector_NPerBlock,
228  BlkGemmPipeSched,
229  BlkGemmPipelineVer,
230  ComputeTypeA,
231  ComputeTypeB,
232  PermuteA,
233  PermuteB>;
236 
238 
239  static constexpr index_t APackedSize = []() {
241  return 2;
242  else
243  return 1;
244  }();
245 
246  static constexpr index_t BPackedSize = []() {
248  return 2;
249  else
250  return 1;
251  }();
252 
262  struct Invoker : public BaseInvoker
263  {
269  template <typename GridwiseGemm>
270  float RunImp(const typename GridwiseGemm::Argument& arg,
271  const StreamConfig& stream_config = StreamConfig{})
272  {
273  if(stream_config.log_level_ > 0)
274  {
275  arg.Print();
276  GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
277  }
278 
279  if(!GridwiseGemm::CheckValidity(arg))
280  {
281  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
282  }
283 
284  index_t gdx, gdy, gdz;
285  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
286 
287  float ave_time = 0;
288 
289  index_t k_grain = arg.KBatch * KPerBlock;
290  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
291 
292  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
293 
294  const auto Run = [&](const auto& kernel) {
295  if(stream_config.flush_cache)
296  {
297  auto arg_ = arg;
298 
299  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
300  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
301  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
302  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
303 
304  auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
305  sizeof(ADataType) / APackedSize;
306  auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
307  sizeof(BDataType) / BPackedSize;
308 
310  arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
311  rotating_mem.Print();
312 
313  auto run_flush_cache = [&]() {
314  // flush icache
316  // rotating mem
317  rotating_mem.Next();
318  // clear c mem
319  if(arg_.KBatch > 1)
320  hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
321  0,
322  arg_.M * arg_.N * sizeof(CDataType),
323  stream_config.stream_id_));
324  };
325 
326  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
327  stream_config,
328  run_flush_cache,
329  kernel,
330  dim3(gdx, gdy, gdz),
331  dim3(BlockSize),
332  0,
333  arg_);
334  }
335  else
336  {
337  if(arg.KBatch > 1)
338  hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
339  0,
340  arg.M * arg.N * sizeof(CDataType),
341  stream_config.stream_id_));
342 
343  ave_time = launch_and_time_kernel(
344  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
345  }
346  };
347 
348  constexpr index_t minimum_occupancy = []() {
349  if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
350  {
351  return 2;
352  }
353  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
354  {
355  return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
356  }
357  else
358  {
359  return 1;
360  }
361  }();
362 
363  if(has_main_k_block_loop)
364  {
365  // Tail number always full
366  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
367  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
368  {
369  if(arg.KBatch > 1)
370  {
371  const auto kernel =
372  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
373  true,
375  minimum_occupancy>;
376  Run(kernel);
377  }
378  else
379  {
380  const auto kernel =
381  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
382  true,
384  minimum_occupancy>;
385  Run(kernel);
386  }
387  }
388  // Tail number could be One to Seven
389  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
390  {
391  if(arg.KBatch > 1)
392  {
393  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
394  {
395  const auto kernel =
396  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
397  true,
399  minimum_occupancy,
401  Run(kernel);
402  }
403  else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
405  {
406  const auto kernel =
407  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
408  true,
410  minimum_occupancy,
412  Run(kernel);
413  }
414 
415  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
416  {
417  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
418  {
419  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
420  GridwiseGemm,
421  true,
423  minimum_occupancy,
425  Run(kernel);
426  }
427  }
428 
429  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
430  {
431  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
433  {
434  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
435  GridwiseGemm,
436  true,
438  minimum_occupancy,
440  Run(kernel);
441  }
442  }
443 
444  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
445  {
446  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
448  {
449  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
450  GridwiseGemm,
451  true,
453  minimum_occupancy,
455  Run(kernel);
456  }
457  }
458 
459  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
460  {
461  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
463  {
464  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
465  GridwiseGemm,
466  true,
468  minimum_occupancy,
470  Run(kernel);
471  }
472  }
473 
474  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
475  {
476  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
477  {
478  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
479  GridwiseGemm,
480  true,
482  minimum_occupancy,
484  Run(kernel);
485  }
486  }
487 
488  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
489  {
490  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
492  {
493  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
494  GridwiseGemm,
495  true,
497  minimum_occupancy,
499  Run(kernel);
500  }
501  }
502  }
503  else
504  {
505  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
506  {
507  const auto kernel =
508  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
509  true,
511  minimum_occupancy,
513  Run(kernel);
514  }
515  else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
517  {
518  const auto kernel =
519  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
520  true,
522  minimum_occupancy,
524  Run(kernel);
525  }
526 
527  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
528  {
529  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
530  {
531  const auto kernel =
532  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
533  true,
535  minimum_occupancy,
537  Run(kernel);
538  }
539  }
540 
541  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
542  {
543  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
545  {
546  const auto kernel =
547  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
548  true,
550  minimum_occupancy,
552  Run(kernel);
553  }
554  }
555 
556  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
557  {
558  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
560  {
561  const auto kernel =
562  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
563  true,
565  minimum_occupancy,
567  Run(kernel);
568  }
569  }
570 
571  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
572  {
573  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
575  {
576  const auto kernel =
577  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
578  true,
580  minimum_occupancy,
582  Run(kernel);
583  }
584  }
585 
586  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
587  {
588  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
589  {
590  const auto kernel =
591  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
592  true,
594  minimum_occupancy,
596  Run(kernel);
597  }
598  }
599 
600  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
601  {
602  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
604  {
605  const auto kernel =
606  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
607  true,
609  minimum_occupancy,
611  Run(kernel);
612  }
613  }
614  }
615  }
616  // Tail number could be Odd or Even
617  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
618  {
619  if(arg.KBatch > 1)
620  {
621  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
622  {
623  const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
624  GridwiseGemm,
625  true,
627  minimum_occupancy,
629  Run(kernel);
630  }
631  else
632  {
633  const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
634  GridwiseGemm,
635  true,
637  minimum_occupancy,
639  Run(kernel);
640  }
641  }
642  else
643  {
644  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
645  {
646  const auto kernel =
648  true,
650  minimum_occupancy,
652  Run(kernel);
653  }
654  else
655  {
656  const auto kernel =
658  true,
660  minimum_occupancy,
662  Run(kernel);
663  }
664  }
665  }
666  else
667  {
668  if(arg.KBatch > 1)
669  {
670  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
671  {
672  const auto kernel =
673  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
674  true,
676  minimum_occupancy,
678  Run(kernel);
679  }
680  else
681  {
682  const auto kernel =
683  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
684  true,
686  minimum_occupancy,
688  Run(kernel);
689  }
690  }
691  else
692  {
693  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
694  {
695  const auto kernel =
696  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
697  true,
699  minimum_occupancy,
701  Run(kernel);
702  }
703  else
704  {
705  const auto kernel =
706  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
707  true,
709  minimum_occupancy,
711  Run(kernel);
712  }
713  }
714  }
715  }
716  else
717  {
718  // Tail number always 1
719  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
720  {
721  if(arg.KBatch > 1)
722  {
723  const auto kernel =
724  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
725  false,
727  minimum_occupancy>;
728  Run(kernel);
729  }
730  else
731  {
732  const auto kernel =
733  kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
734  false,
736  minimum_occupancy>;
737  Run(kernel);
738  }
739  }
740  }
741 
742  return ave_time;
743  }
744 
746  // polymorphic
     // Type-erased entry point required by the BaseInvoker interface:
     // downcasts the base argument to this device op's concrete Argument and
     // forwards to the typed Run overload (presumably supplied via the
     // INVOKER_RUN3_IMPL macro above -- confirm against device_base.hpp).
     // NOTE(review): dynamic_cast yields nullptr if p_arg is not an Argument,
     // and the immediate dereference would then be undefined behavior;
     // callers must pass the matching argument type.
747  float Run(const BaseArgument* p_arg,
748  const StreamConfig& stream_config = StreamConfig{}) override
749  {
750  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
751  }
752  };
753 
     // Compile-time sanity check of this instance's template configuration.
     // Currently a stub that accepts every configuration (see TODO below);
     // runtime validity is checked separately in IsSupportedArgument().
754  static constexpr bool IsValidCompilationParameter()
755  {
756  // TODO: properly implement this check
757  return true;
758  }
759 
     // Static capability check: returns true only when the current device and
     // the runtime problem description in `arg` can be handled by this
     // instance. Falls through to `return false` when no wave64/wave32
     // gridwise-gemm instantiation is available.
760  static bool IsSupportedArgument(const Argument& arg)
761  {
     // Reject devices whose XDL/WMMA units cannot process these compute
     // types at the chosen MPerXDL x NPerXDL tile.
762  if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
763  {
764  return false;
765  }
     // Split-K (KBatch > 1) restrictions: rejected outright on gfx11, when
     // bf16 atomics are unavailable for a bhalf_t C, and for 1-byte C types.
766  if(arg.KBatch > 1)
767  {
768  if(is_gfx11_supported())
769  {
770  return false;
771  }
772 
773  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
774  {
775  return false;
776  }
777 
778  if(sizeof(CDataType) == 1)
779  {
780  return false;
781  }
782  }
783 
     // gfx11 has no f8/bf8 A-operand support in this kernel.
784  if(is_gfx11_supported())
785  {
786  if constexpr(std::is_same_v<ADataType, ck::f8_t> ||
787  std::is_same_v<ADataType, ck::bf8_t>)
788  {
789  return false;
790  }
791  }
792 
     // K must be divisible by both vector widths unless a K-padding
     // specialization was compiled in.
793  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
794  GemmSpec == GemmSpecialization::NKPadding ||
795  GemmSpec == GemmSpecialization::MNKPadding ||
796  GemmSpec == GemmSpecialization::KPadding))
797  {
798  return false;
799  }
800 
     // Dispatch the final validity check to the gridwise GEMM matching the
     // device's wavefront size (64 vs 32).
801  if(get_warp_size() == 64)
802  {
803  if constexpr(NXdlPerWave64 > 0)
804  {
805  return GridwiseGemm64::CheckValidity(arg);
806  }
807  }
808  else
809  {
810  if constexpr(NXdlPerWave32 > 0)
811  {
     // NOTE(review): listing line 812 (the `return
     // GridwiseGemm32::CheckValidity(` call head) appears to have been
     // dropped by the documentation extraction -- verify against the
     // original header.
813  reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
814  }
815  }
816  return false;
817  }
818 
819  // polymorphic
     // Type-erased wrapper over the static IsSupportedArgument overload.
     // NOTE(review): dynamic_cast yields nullptr for a mismatched argument
     // type and the dereference would then be undefined behavior; callers
     // must pass this device op's Argument.
820  bool IsSupportedArgument(const BaseArgument* p_arg) override
821  {
822  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
823  }
824 
     // Exposes the compile-time K-dimension block tile size to callers that
     // only hold the type-erased base interface.
825  index_t GetKPerBlock() override { return KPerBlock; }
826 
     // Report the PermuteA / PermuteB template flags (whether the respective
     // operand is expected in the permuted layout) through the base interface.
827  bool GetPermuteA() override { return PermuteA; }
828  bool GetPermuteB() override { return PermuteB; }
829 
     // Builds the concrete GEMM Argument from typed device pointers, problem
     // sizes (M, N, K), leading-dimension strides, and the split-K factor
     // KBatch. The three elementwise-operation parameters are accepted for
     // interface conformity but are unnamed and not forwarded.
830  static auto MakeArgument(const ADataType* p_a,
831  const BDataType* p_b,
832  CDataType* p_c,
833  index_t M,
834  index_t N,
835  index_t K,
836  index_t StrideA,
837  index_t StrideB,
838  index_t StrideC,
839  index_t KBatch,
840  AElementwiseOperation,
841  BElementwiseOperation,
842  CElementwiseOperation)
843  {
844  return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
845  }
846 
     // Factory for a default-constructed Invoker (stateless).
847  static auto MakeInvoker() { return Invoker{}; }
848 
849  // polymorphic
     // Type-erased Argument factory: casts the untyped A/B/C pointers to the
     // data types fixed by this instantiation and heap-allocates the concrete
     // Argument. The elementwise-operation parameters are unnamed and unused,
     // mirroring the static MakeArgument above.
850  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
851  const void* p_b,
852  void* p_c,
853  index_t M,
854  index_t N,
855  index_t K,
856  index_t StrideA,
857  index_t StrideB,
858  index_t StrideC,
859  index_t KBatch,
860  AElementwiseOperation,
861  BElementwiseOperation,
862  CElementwiseOperation) override
863  {
864  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
865  static_cast<const BDataType*>(p_b),
866  static_cast<CDataType*>(p_c),
867  M,
868  N,
869  K,
870  StrideA,
871  StrideB,
872  StrideC,
873  KBatch);
874  }
875 
876  // polymorphic
     // Type-erased Invoker factory required by the base interface.
     // NOTE(review): the Invoker{} argument is redundant --
     // std::make_unique<Invoker>() would default-construct equivalently.
877  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
878  {
879  return std::make_unique<Invoker>(Invoker{});
880  }
881 
882  // polymorphic
     // Renders a human-readable one-line description of this instance's
     // compile-time configuration (layouts, tile sizes, pipeline scheduler /
     // version, prefetch depth, K-pack stride).
883  std::string GetTypeString() const override
884  {
885  auto str = std::stringstream();
886 
     // NOTE(review): listing lines 888-889 and 892-896 (the initializer
     // entries of these two enum-to-string maps) appear to have been dropped
     // by the documentation extraction -- verify against the original header.
887  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
890 
891  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
897 
     // Pull pipeline metadata from whichever gridwise GEMM instantiation
     // matches the device's wavefront size; both stay 0 if neither variant
     // was compiled in (NXdlPerWave* == 0).
898  index_t PrefetchStages = 0;
899  index_t AMmaKStride = 0;
900  if(get_warp_size() == 64)
901  {
902  if constexpr(NXdlPerWave64 > 0)
903  {
904  PrefetchStages = GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
905  AMmaKStride = GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride;
906  }
907  }
908  else
909  {
910  if constexpr(NXdlPerWave32 > 0)
911  {
912  PrefetchStages = GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages;
913  AMmaKStride = GridwiseGemm32::BlockwiseGemmPipe::AMmaKStride;
914  }
915  }
916 
917  // clang-format off
918  str << "DeviceGemmXdlUniversal"
919  << "<"
920  << getGemmSpecializationString(GemmSpec) << ", "
921  << std::string(ALayout::name)[0]
922  << std::string(BLayout::name)[0]
923  << std::string(CLayout::name)[0]
924  << ">"
925  << " BlkSize: "
926  << BlockSize << ", "
927  << "BlkTile: "
928  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
929  << "WaveTile: "
930  << MPerXDL<<"x"<<NPerXDL << ", "
931  << "WaveMap: "
932  << MXdlPerWave<<"x" << NXdlPerWave<<", "
933  << "VmemReadVec: "
934  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
935  << "BlkGemmPipelineScheduler: "
936  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
937  << "BlkGemmPipelineVersion: "
938  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
939  << "BlkGemmPipelinePrefetchStages: "
940  << PrefetchStages << ", "
941  << "Kpack: "
942  << AMmaKStride;
943  // clang-format on
944 
945  return str.str();
946  }
948 };
949 
950 } // namespace device
951 } // namespace tensor_operation
952 } // namespace ck
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:114
#define REGISTER_EXTRA_PRINTING_METHODS
Definition: device_base.hpp:47
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:361
Definition: ck.hpp:268
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
constexpr bool is_same_v
Definition: type.hpp:283
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
int32_t index_t
Definition: ck.hpp:299
bool is_gfx11_supported()
Definition: device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:108
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:716
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:247
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1202
Definition: data_type.hpp:187
Definition: device_base.hpp:197
Definition: device_base.hpp:208
Helper structure responsible for kernel invocation.
Definition: device_gemm_xdl_cshuffle_v3.hpp:263
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
This function issues GPU kernel execution.
Definition: device_gemm_xdl_cshuffle_v3.hpp:270
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:747
"Universal" GEMM operation with SplitK support.
Definition: device_gemm_xdl_cshuffle_v3.hpp:178
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_gemm_xdl_cshuffle_v3.hpp:181
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:820
bool GetPermuteA() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:827
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:877
index_t GetKPerBlock() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:825
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_cshuffle_v3.hpp:760
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl_cshuffle_v3.hpp:754
std::string GetTypeString() const override
Definition: device_gemm_xdl_cshuffle_v3.hpp:883
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl_cshuffle_v3.hpp:830
bool GetPermuteB() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:828
static auto MakeInvoker()
Definition: device_gemm_xdl_cshuffle_v3.hpp:847
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:850
static constexpr index_t BPackedSize
Definition: device_gemm_xdl_cshuffle_v3.hpp:246
typename GridwiseGemm64::Argument Argument
Definition: device_gemm_xdl_cshuffle_v3.hpp:237
static constexpr auto NXdlPerWave32
Definition: device_gemm_xdl_cshuffle_v3.hpp:182
static constexpr index_t APackedSize
Definition: device_gemm_xdl_cshuffle_v3.hpp:239
Definition: device_gemm_v2.hpp:22
Definition: flush_cache.hpp:283