/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Source File
device_gemm_xdl_cshuffle_v3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
// Device-level "universal" GEMM (C = elementwise(A * B)) using AMD XDL/MFMA
// instructions with a CShuffle epilogue and SplitK (KBatch) support.
// Template parameters select layouts, data types, block/wave tiling, LDS
// transfer shapes, and the blockwise pipeline configuration.
// NOTE(review): this listing was extracted from generated documentation and
// lines 162-164 of the original are missing here — presumably the
// BlkGemmPipeSched / BlkGemmPipelineVer non-type parameters referenced later
// in this file; confirm against the original header.
123 template <typename ALayout,
124  typename BLayout,
125  typename CLayout,
126  typename ADataType,
127  typename BDataType,
128  typename CDataType,
129  typename GemmAccDataType,
130  typename CShuffleDataType,
131  typename AElementwiseOperation,
132  typename BElementwiseOperation,
133  typename CElementwiseOperation,
134  GemmSpecialization GemmSpec,
135  index_t BlockSize,
136  index_t MPerBlock,
137  index_t NPerBlock,
138  index_t KPerBlock,
139  index_t AK1,
140  index_t BK1,
141  index_t MPerXDL,
142  index_t NPerXDL,
143  index_t MXdlPerWave,
144  index_t NXdlPerWave,
145  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
146  typename ABlockTransferThreadClusterArrangeOrder,
147  typename ABlockTransferSrcAccessOrder,
148  index_t ABlockTransferSrcVectorDim,
149  index_t ABlockTransferSrcScalarPerVector,
150  index_t ABlockTransferDstScalarPerVector_AK1,
151  bool ABlockLdsExtraM,
152  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
153  typename BBlockTransferThreadClusterArrangeOrder,
154  typename BBlockTransferSrcAccessOrder,
155  index_t BBlockTransferSrcVectorDim,
156  index_t BBlockTransferSrcScalarPerVector,
157  index_t BBlockTransferDstScalarPerVector_BK1,
158  bool BBlockLdsExtraN,
159  index_t CShuffleMXdlPerWavePerShuffle,
160  index_t CShuffleNXdlPerWavePerShuffle,
161  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
162  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
// ComputeTypeA/B allow computing in a different precision than CDataType;
// PermuteA/PermuteB enable pre-permuted (packed) operand layouts.
165  typename ComputeTypeA = CDataType,
166  typename ComputeTypeB = ComputeTypeA,
167  bool PermuteA = false,
168  bool PermuteB = false>
169 struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
170  BLayout,
171  CLayout,
172  ADataType,
173  BDataType,
174  CDataType,
175  AElementwiseOperation,
176  BElementwiseOperation,
177  CElementwiseOperation>
178 {
179  template <bool isWave64>
180  static constexpr auto GetNXdlPerWave()
181  {
182  constexpr index_t Waves = isWave64 ? BlockSize / 64 : BlockSize / 32;
183  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
184  static_assert(MWaves > 0);
185 
186  constexpr index_t NWaves = Waves / MWaves;
187  if constexpr(NWaves == 0)
188  {
189  return 0;
190  }
191  else
192  {
193  if constexpr(NPerBlock % (NPerXDL * NWaves) == 0)
194  {
195  return NPerBlock / (NWaves * NPerXDL);
196  }
197  else
198  {
199  return 0;
200  }
201  }
202  }
203  // GridwiseGemm
// N-direction XDL-per-wave counts for wave64 and wave32 devices.
// A value of 0 marks that wave size as unusable for this instantiation;
// Run/IsSupportedArgument/GetTypeString all branch on these at compile time.
204  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
205  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
206 
// Parameterizes the gridwise GEMM over NXdlPerWave_, so the wave64 and
// wave32 variants can be instantiated from the same argument list.
// NOTE(review): line 208 of the original was lost in extraction — presumably
// the `using ... = GridwiseGemm_xdl_cshuffle_v3<` alias header that these
// template arguments belong to; confirm against the original file.
207  template <index_t NXdlPerWave_>
209  ALayout,
210  BLayout,
211  CLayout,
212  ADataType,
213  BDataType,
214  GemmAccDataType,
215  CShuffleDataType,
216  CDataType,
217  AElementwiseOperation,
218  BElementwiseOperation,
219  CElementwiseOperation,
220  GemmSpec,
221  BlockSize,
222  MPerBlock,
223  NPerBlock,
224  KPerBlock,
225  AK1,
226  BK1,
227  MPerXDL,
228  NPerXDL,
229  MXdlPerWave,
230  NXdlPerWave_,
231  ABlockTransferThreadClusterLengths_AK0_M_AK1,
232  ABlockTransferThreadClusterArrangeOrder,
233  ABlockTransferSrcAccessOrder,
234  ABlockTransferSrcVectorDim,
235  ABlockTransferSrcScalarPerVector,
236  ABlockTransferDstScalarPerVector_AK1,
// NOTE(review): the literal `false` below sits between the A-transfer
// parameters and ABlockLdsExtraM; its exact parameter name is not visible
// in this extraction — confirm against GridwiseGemm_xdl_cshuffle_v3.
237  false,
238  ABlockLdsExtraM,
239  BBlockTransferThreadClusterLengths_BK0_N_BK1,
240  BBlockTransferThreadClusterArrangeOrder,
241  BBlockTransferSrcAccessOrder,
242  BBlockTransferSrcVectorDim,
243  BBlockTransferSrcScalarPerVector,
244  BBlockTransferDstScalarPerVector_BK1,
245  false,
246  BBlockLdsExtraN,
247  CShuffleMXdlPerWavePerShuffle,
248  CShuffleNXdlPerWavePerShuffle,
249  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
250  CShuffleBlockTransferScalarPerVector_NPerBlock,
251  BlkGemmPipeSched,
252  BlkGemmPipelineVer,
253  ComputeTypeA,
254  ComputeTypeB,
255  PermuteA,
256  PermuteB>;
259 
261 
// Elements packed per addressable unit for the A/B operands (2 for packed
// sub-byte types, 1 otherwise); used below when sizing the rotating
// flush-cache buffers in RunImp.
// NOTE(review): the `if constexpr(...)` condition lines (263 and 270) were
// lost in extraction — presumably a packed-int4-style type check on
// ADataType/BDataType; confirm against the original header.
262  static constexpr index_t APackedSize = []() {
264  return 2;
265  else
266  return 1;
267  }();
268 
269  static constexpr index_t BPackedSize = []() {
271  return 2;
272  else
273  return 1;
274  }();
275 
285  struct Invoker : public BaseInvoker
286  {
// Launches and times one GEMM dispatch for the given gridwise specialization.
//
// Selects a __global__ kernel instantiation from the compile-time pipeline
// scheduler/version and from runtime properties of the problem (whether the
// main K loop runs, and the K-loop tail number), then launches it through the
// local Run lambda.
//
// @param arg            problem descriptor (pointers, sizes, strides, KBatch).
// @param stream_config  stream handle plus timing / flush-cache options.
// @return average kernel time as reported by the launch-and-time helpers.
//
// NOTE(review): this listing was extracted from generated documentation and
// several original lines are missing (the validity-check condition before the
// throw, the rotating-buffer construction at line 327, and the kernel
// template-argument lines). The code is reproduced exactly as extracted.
287  template <typename GridwiseGemm>
288  float RunImp(const typename GridwiseGemm::Argument& arg,
289  const StreamConfig& stream_config = StreamConfig{})
290  {
291  if(stream_config.log_level_ > 0)
292  {
293  arg.Print();
294  GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
295  }
296 
// NOTE(review): the guarding condition (line 297) was lost in extraction —
// presumably `if(!GridwiseGemm::CheckValidity(arg))`.
298  {
299  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
300  }
301 
302  index_t gdx, gdy, gdz;
303  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
304 
305  float ave_time = 0;
306 
// Round K up to a multiple of KBatch*KPerBlock, expressed per K-batch slice.
307  index_t k_grain = arg.KBatch * KPerBlock;
308  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
309 
310  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
311 
// Launch helper shared by every kernel variant below. With flush_cache it
// rotates the A/B buffers and flushes the icache between timed iterations;
// in both paths C is zeroed first when KBatch > 1 (split-K accumulates
// partial results into C, so it must start from zero).
312  const auto Run = [&](const auto& kernel) {
313  if(stream_config.flush_cache)
314  {
315  auto arg_ = arg;
316 
317  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
318  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
319  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
320  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
321 
// Byte sizes of the rotating copies; PackedSize accounts for sub-byte types.
322  auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
323  sizeof(ADataType) / APackedSize;
324  auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
325  sizeof(BDataType) / BPackedSize;
326 
// NOTE(review): the declaration of `rotating_mem` (line 327) was lost in
// extraction — presumably a RotatingMemWrapper over arg_.
328  arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
329  rotating_mem.Print();
330 
331  auto run_flush_cache = [&]() {
332  // flush icache
334  // rotating mem
335  rotating_mem.Next();
336  // clear c mem
337  if(arg_.KBatch > 1)
338  hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
339  0,
340  arg_.M * arg_.N * sizeof(CDataType),
341  stream_config.stream_id_));
342  };
343 
344  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
345  stream_config,
346  run_flush_cache,
347  kernel,
348  dim3(gdx, gdy, gdz),
349  dim3(BlockSize),
350  0,
351  arg_);
352  }
353  else
354  {
355  if(arg.KBatch > 1)
356  hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
357  0,
358  arg.M * arg.N * sizeof(CDataType),
359  stream_config.stream_id_));
360 
361  ave_time = launch_and_time_kernel(
362  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
363  }
364  };
365 
// Compile-time occupancy hint passed to the kernel's launch bounds,
// chosen from the pipeline scheduler/version and the per-thread tile size.
366  constexpr index_t minimum_occupancy = []() {
367  if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
368  {
369  return 2;
370  }
371  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
372  {
373  return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
374  }
375  else
376  {
377  return 1;
378  }
379  }();
380 
// Kernel selection. In every branch below the kernel's template-argument
// lines were lost in extraction; only the has_main_k_block_loop flag
// (true/false) and minimum_occupancy survive. The KBatch > 1 branches
// presumably select the atomic-accumulation C store — confirm against the
// original file.
381  if(has_main_k_block_loop)
382  {
383  // Tail number always full
384  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
385  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
386  {
387  if(arg.KBatch > 1)
388  {
389  const auto kernel =
391  true,
393  minimum_occupancy>;
394  Run(kernel);
395  }
396  else
397  {
398  const auto kernel =
400  true,
402  minimum_occupancy>;
403  Run(kernel);
404  }
405  }
406  // Tail number could be One to Seven
407  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
408  {
409  if(arg.KBatch > 1)
410  {
// NOTE(review): each tail-number test below lost its condition line; the
// surviving `else if` at line 421 shows the pattern is
// `if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::...)`.
412  {
413  const auto kernel =
415  true,
417  minimum_occupancy,
419  Run(kernel);
420  }
421  else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
423  {
424  const auto kernel =
426  true,
428  minimum_occupancy,
430  Run(kernel);
431  }
432 
// Tail numbers Three..Seven are only instantiated when the pipeline
// actually prefetches that deep, keeping compile time and code size down.
433  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
434  {
436  {
437  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
438  GridwiseGemm,
439  true,
441  minimum_occupancy,
443  Run(kernel);
444  }
445  }
446 
447  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
448  {
451  {
452  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
453  GridwiseGemm,
454  true,
456  minimum_occupancy,
458  Run(kernel);
459  }
460  }
461 
462  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
463  {
466  {
467  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
468  GridwiseGemm,
469  true,
471  minimum_occupancy,
473  Run(kernel);
474  }
475  }
476 
477  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
478  {
481  {
482  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
483  GridwiseGemm,
484  true,
486  minimum_occupancy,
488  Run(kernel);
489  }
490  }
491 
492  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
493  {
495  {
496  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
497  GridwiseGemm,
498  true,
500  minimum_occupancy,
502  Run(kernel);
503  }
504  }
505 
506  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
507  {
510  {
511  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
512  GridwiseGemm,
513  true,
515  minimum_occupancy,
517  Run(kernel);
518  }
519  }
520  }
// KBatch == 1: same tail-number ladder, without the split-K C store.
521  else
522  {
524  {
525  const auto kernel =
527  true,
529  minimum_occupancy,
531  Run(kernel);
532  }
533  else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
535  {
536  const auto kernel =
538  true,
540  minimum_occupancy,
542  Run(kernel);
543  }
544 
545  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
546  {
548  {
549  const auto kernel =
551  true,
553  minimum_occupancy,
555  Run(kernel);
556  }
557  }
558 
559  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
560  {
563  {
564  const auto kernel =
566  true,
568  minimum_occupancy,
570  Run(kernel);
571  }
572  }
573 
574  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
575  {
578  {
579  const auto kernel =
581  true,
583  minimum_occupancy,
585  Run(kernel);
586  }
587  }
588 
589  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
590  {
593  {
594  const auto kernel =
596  true,
598  minimum_occupancy,
600  Run(kernel);
601  }
602  }
603 
604  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
605  {
607  {
608  const auto kernel =
610  true,
612  minimum_occupancy,
614  Run(kernel);
615  }
616  }
617 
618  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
619  {
622  {
623  const auto kernel =
625  true,
627  minimum_occupancy,
629  Run(kernel);
630  }
631  }
632  }
633  }
634  // Tail number could be Odd or Even
635  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
636  {
// v4 uses the double-LDS kernel variant (kernel_gemm_xdl_cshuffle_v3_2lds).
637  if(arg.KBatch > 1)
638  {
640  {
641  const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
642  GridwiseGemm,
643  true,
645  minimum_occupancy,
647  Run(kernel);
648  }
649  else
650  {
651  const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
652  GridwiseGemm,
653  true,
655  minimum_occupancy,
657  Run(kernel);
658  }
659  }
660  else
661  {
663  {
664  const auto kernel =
666  true,
668  minimum_occupancy,
670  Run(kernel);
671  }
672  else
673  {
674  const auto kernel =
676  true,
678  minimum_occupancy,
680  Run(kernel);
681  }
682  }
683  }
684  else
685  {
686  if(arg.KBatch > 1)
687  {
689  {
690  const auto kernel =
692  true,
694  minimum_occupancy,
696  Run(kernel);
697  }
698  else
699  {
700  const auto kernel =
702  true,
704  minimum_occupancy,
706  Run(kernel);
707  }
708  }
709  else
710  {
712  {
713  const auto kernel =
715  true,
717  minimum_occupancy,
719  Run(kernel);
720  }
721  else
722  {
723  const auto kernel =
725  true,
727  minimum_occupancy,
729  Run(kernel);
730  }
731  }
732  }
733  }
// No main K loop: only pipeline v1 supports this case.
734  else
735  {
736  // Tail number always 1
737  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
738  {
739  if(arg.KBatch > 1)
740  {
741  const auto kernel =
743  false,
745  minimum_occupancy>;
746  Run(kernel);
747  }
748  else
749  {
750  const auto kernel =
752  false,
754  minimum_occupancy>;
755  Run(kernel);
756  }
757  }
758  }
759 
760  return ave_time;
761  }
762 
768  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
769  {
770  if(get_warp_size() == 64)
771  {
772  if constexpr(NXdlPerWave64 > 0)
773  {
774  return RunImp<GridwiseGemm64>(arg, stream_config);
775  }
776  }
777  else
778  {
779  if constexpr(NXdlPerWave32 > 0)
780  {
781  return RunImp<GridwiseGemm32>(
782  reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg),
783  stream_config);
784  }
785  }
786  return 0;
787  }
788  // polymorphic
// Type-erased entry point: downcasts and forwards to the typed Run above.
// NOTE(review): dynamic_cast returns nullptr for a mismatched argument type
// and the dereference here would then be undefined behavior — callers must
// pass an Argument created by this device op.
789  float Run(const BaseArgument* p_arg,
790  const StreamConfig& stream_config = StreamConfig{}) override
791  {
792  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
793  }
794  };
795 
// Compile-time parameter validation hook; currently accepts everything
// (runtime checks live in IsSupportedArgument / CheckValidity instead).
796  static constexpr bool IsValidCompilationParameter()
797  {
798  // TODO: properly implement this check
799  return true;
800  }
801 
// Returns true when this instantiation can run the given problem on the
// current device: requires XDL support, restricts split-K (KBatch > 1) by
// architecture and output type, enforces K-vector divisibility unless a
// K-padding specialization is selected, and finally defers to the gridwise
// GEMM's own CheckValidity for the active wave size.
802  static bool IsSupportedArgument(const Argument& arg)
803  {
804  if(!ck::is_xdl_supported())
805  {
806  return false;
807  }
808 
// Split-K restrictions: needs atomic C updates, so gfx11, bf16 without
// atomic support, and 8-bit C types are rejected.
809  if(arg.KBatch > 1)
810  {
811  if(is_gfx11_supported())
812  {
813  return false;
814  }
815 
816  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
817  {
818  return false;
819  }
820 
821  if(sizeof(CDataType) == 1)
822  {
823  return false;
824  }
825  }
826 
// NOTE(review): the condition on line 827 was lost in extraction —
// presumably `if(is_gfx12_supported())`, which would restrict that
// architecture to 16x16 XDL tiles; confirm against the original file.
828  {
829  if(MPerXDL != 16 || NPerXDL != 16)
830  {
831  return false;
832  }
833  }
834 
// gfx11 has no f8/bf8 MFMA support for the A operand.
835  if(is_gfx11_supported())
836  {
837  if constexpr(std::is_same_v<ADataType, ck::f8_t> ||
838  std::is_same_v<ADataType, ck::bf8_t>)
839  {
840  return false;
841  }
842  }
843 
// K must divide the vector loads unless a K-padding specialization
// absorbs the remainder.
844  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
845  GemmSpec == GemmSpecialization::NKPadding ||
846  GemmSpec == GemmSpecialization::MNKPadding ||
847  GemmSpec == GemmSpecialization::KPadding))
848  {
849  return false;
850  }
851 
852  if(get_warp_size() == 64)
853  {
854  if constexpr(NXdlPerWave64 > 0)
855  {
856  return GridwiseGemm64::CheckValidity(arg);
857  }
858  else
859  {
860  return false;
861  }
862  }
863  else
864  {
865  if constexpr(NXdlPerWave32 > 0)
866  {
// NOTE(review): line 867 was lost in extraction — presumably
// `return GridwiseGemm32::CheckValidity(`.
868  reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
869  }
870  else
871  {
872  return false;
873  }
874  }
875  }
876 
877  // polymorphic
// Type-erased overload: downcasts and forwards to the static overload above.
// NOTE(review): as with Run, a mismatched argument type makes dynamic_cast
// return nullptr and the dereference is then undefined behavior.
878  bool IsSupportedArgument(const BaseArgument* p_arg) override
879  {
880  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
881  }
882 
// Accessors exposing compile-time template parameters through the
// polymorphic DeviceGemmV2 interface.
883  index_t GetKPerBlock() override { return KPerBlock; }
884 
885  bool GetPermuteA() override { return PermuteA; }
886  bool GetPermuteB() override { return PermuteB; }
887 
888  static auto MakeArgument(const ADataType* p_a,
889  const BDataType* p_b,
890  CDataType* p_c,
891  index_t M,
892  index_t N,
893  index_t K,
894  index_t StrideA,
895  index_t StrideB,
896  index_t StrideC,
897  index_t KBatch,
898  AElementwiseOperation,
899  BElementwiseOperation,
900  CElementwiseOperation)
901  {
902  return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
903  }
904 
905  static auto MakeInvoker() { return Invoker{}; }
906 
907  // polymorphic
908  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
909  const void* p_b,
910  void* p_c,
911  index_t M,
912  index_t N,
913  index_t K,
914  index_t StrideA,
915  index_t StrideB,
916  index_t StrideC,
917  index_t KBatch,
918  AElementwiseOperation,
919  BElementwiseOperation,
920  CElementwiseOperation) override
921  {
922  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
923  static_cast<const BDataType*>(p_b),
924  static_cast<CDataType*>(p_c),
925  M,
926  N,
927  K,
928  StrideA,
929  StrideB,
930  StrideC,
931  KBatch);
932  }
933 
934  // polymorphic
935  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
936  {
937  return std::make_unique<Invoker>(Invoker{});
938  }
939 
940  // polymorphic
// Builds a human-readable description of this instantiation (tile sizes,
// wave mapping, vector widths, pipeline configuration) for logging and
// instance selection.
941  std::string GetTypeString() const override
942  {
943  auto str = std::stringstream();
944 
// NOTE(review): the initializer entries of both maps (lines 946-947 and
// 950-954) were lost in extraction — presumably scheduler/version enum ->
// name pairs; confirm against the original file.
945  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
948 
949  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
955 
// Pull pipeline constants from whichever wave variant is active on this
// device; both stay 0 when that variant is compiled out.
956  index_t PrefetchStages = 0;
957  index_t AMmaKStride = 0;
958  if(get_warp_size() == 64)
959  {
960  if constexpr(NXdlPerWave64 > 0)
961  {
962  PrefetchStages = GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
963  AMmaKStride = GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride;
964  }
965  }
966  else
967  {
968  if constexpr(NXdlPerWave32 > 0)
969  {
970  PrefetchStages = GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages;
971  AMmaKStride = GridwiseGemm32::BlockwiseGemmPipe::AMmaKStride;
972  }
973  }
974 
975  // clang-format off
976  str << "DeviceGemmXdlUniversal"
977  << "<"
978  << getGemmSpecializationString(GemmSpec) << ", "
979  << std::string(ALayout::name)[0]
980  << std::string(BLayout::name)[0]
981  << std::string(CLayout::name)[0]
982  << ">"
983  << " BlkSize: "
984  << BlockSize << ", "
985  << "BlkTile: "
986  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
987  << "WaveTile: "
988  << MPerXDL<<"x"<<NPerXDL << ", "
989  << "WaveMap: "
990  << MXdlPerWave<<"x" << NXdlPerWave<<", "
991  << "VmemReadVec: "
992  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
993  << "BlkGemmPipelineScheduler: "
994  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
995  << "BlkGemmPipelineVersion: "
996  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
997  << "BlkGemmPipelinePrefetchStages: "
998  << PrefetchStages << ", "
999  << "Kpack: "
1000  << AMmaKStride;
1001  // clang-format on
1002 
1003  return str.str();
1004  }
1006 };
1007 
1008 } // namespace device
1009 } // namespace tensor_operation
1010 } // namespace ck
#define REGISTER_EXTRA_PRINTING_METHODS
Definition: device_base.hpp:46
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:216
Definition: ck.hpp:267
bool is_xdl_supported()
Definition: device_prop.hpp:68
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:59
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
bool is_gfx12_supported()
Definition: device_prop.hpp:55
constexpr bool is_same_v
Definition: type.hpp:283
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
int32_t index_t
Definition: ck.hpp:298
bool is_gfx11_supported()
Definition: device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:85
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:716
CDataType * p_c_grid
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:760
index_t N
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:695
index_t KBatch
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:700
index_t M
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:694
index_t K
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:696
__host__ void Print() const
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:673
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:247
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:451
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1456
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1237
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:299
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:369
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1449
Definition: data_type.hpp:186
Definition: device_base.hpp:51
Definition: device_base.hpp:62
Helper structure responsible for kernel invocation.
Definition: device_gemm_xdl_cshuffle_v3.hpp:286
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_cshuffle_v3.hpp:288
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:789
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
This function issues GPU kernel execution.
Definition: device_gemm_xdl_cshuffle_v3.hpp:768
"Universal" GEMM operation with SplitK support.
Definition: device_gemm_xdl_cshuffle_v3.hpp:178
static constexpr auto GetNXdlPerWave()
Definition: device_gemm_xdl_cshuffle_v3.hpp:180
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:878
GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition: device_gemm_xdl_cshuffle_v3_b_scale.hpp:138
bool GetPermuteA() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:885
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:935
index_t GetKPerBlock() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:883
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_cshuffle_v3.hpp:802
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl_cshuffle_v3.hpp:796
std::string GetTypeString() const override
Definition: device_gemm_xdl_cshuffle_v3.hpp:941
static constexpr auto NXdlPerWave32
Definition: device_gemm_xdl_cshuffle_v3.hpp:205
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl_cshuffle_v3.hpp:888
static constexpr auto NXdlPerWave64
Definition: device_gemm_xdl_cshuffle_v3.hpp:204
bool GetPermuteB() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:886
static auto MakeInvoker()
Definition: device_gemm_xdl_cshuffle_v3.hpp:905
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:908
static constexpr index_t BPackedSize
Definition: device_gemm_xdl_cshuffle_v3.hpp:269
typename GridwiseGemm64::Argument Argument
Definition: device_gemm_xdl_cshuffle_v3.hpp:260
static constexpr index_t APackedSize
Definition: device_gemm_xdl_cshuffle_v3.hpp:262
Definition: device_gemm_v2.hpp:22
Definition: flush_cache.hpp:138