21 namespace tensor_operation {
 
  123 template <
typename ALayout,
 
  129           typename GemmAccDataType,
 
  130           typename CShuffleDataType,
 
  131           typename AElementwiseOperation,
 
  132           typename BElementwiseOperation,
 
  133           typename CElementwiseOperation,
 
  145           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  146           typename ABlockTransferThreadClusterArrangeOrder,
 
  147           typename ABlockTransferSrcAccessOrder,
 
  148           index_t ABlockTransferSrcVectorDim,
 
  149           index_t ABlockTransferSrcScalarPerVector,
 
  150           index_t ABlockTransferDstScalarPerVector_AK1,
 
  151           bool ABlockLdsExtraM,
 
  152           typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  153           typename BBlockTransferThreadClusterArrangeOrder,
 
  154           typename BBlockTransferSrcAccessOrder,
 
  155           index_t BBlockTransferSrcVectorDim,
 
  156           index_t BBlockTransferSrcScalarPerVector,
 
  157           index_t BBlockTransferDstScalarPerVector_BK1,
 
  158           bool BBlockLdsExtraN,
 
  159           index_t CShuffleMXdlPerWavePerShuffle,
 
  160           index_t CShuffleNXdlPerWavePerShuffle,
 
  161           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  162           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
 
  165           typename ComputeTypeA                       = CDataType,
 
  166           typename ComputeTypeB                       = ComputeTypeA,
 
  167           bool PermuteA                               = 
false,
 
  168           bool PermuteB                               = 
false>
 
  175                                                        AElementwiseOperation,
 
  176                                                        BElementwiseOperation,
 
  177                                                        CElementwiseOperation>
 
  184     template <index_t NXdlPerWave_>
 
  194         AElementwiseOperation,
 
  195         BElementwiseOperation,
 
  196         CElementwiseOperation,
 
  208         ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  209         ABlockTransferThreadClusterArrangeOrder,
 
  210         ABlockTransferSrcAccessOrder,
 
  211         ABlockTransferSrcVectorDim,
 
  212         ABlockTransferSrcScalarPerVector,
 
  213         ABlockTransferDstScalarPerVector_AK1,
 
  216         BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  217         BBlockTransferThreadClusterArrangeOrder,
 
  218         BBlockTransferSrcAccessOrder,
 
  219         BBlockTransferSrcVectorDim,
 
  220         BBlockTransferSrcScalarPerVector,
 
  221         BBlockTransferDstScalarPerVector_BK1,
 
  224         CShuffleMXdlPerWavePerShuffle,
 
  225         CShuffleNXdlPerWavePerShuffle,
 
  226         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  227         CShuffleBlockTransferScalarPerVector_NPerBlock,
 
  269         template <
typename Gr
idwiseGemm>
 
  270         float RunImp(
const typename GridwiseGemm::Argument& arg,
 
  273             if(stream_config.log_level_ > 0)
 
  276                 GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
 
  279             if(!GridwiseGemm::CheckValidity(arg))
 
  281                 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
 
  285             std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
 
  289             index_t k_grain = arg.KBatch * KPerBlock;
 
  290             index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
 
  292             const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
 
  294             const auto Run = [&](
const auto& kernel) {
 
  295                 if(stream_config.flush_cache)
 
  299                     const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
 
  300                         arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
 
  301                     const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
 
  302                         arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
 
  304                     auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
 
  306                     auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
 
  310                         arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
 
  311                     rotating_mem.Print();
 
  313                     auto run_flush_cache = [&]() {
 
  320                             hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
 
  322                                                              arg_.M * arg_.N * 
sizeof(CDataType),
 
  323                                                              stream_config.stream_id_));
 
  326                     ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
 
  338                         hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
 
  340                                                          arg.M * arg.N * 
sizeof(CDataType),
 
  341                                                          stream_config.stream_id_));
 
  344                         stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
 
  348             constexpr 
index_t minimum_occupancy = []() {
 
  355                     return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
 
  363             if(has_main_k_block_loop)
 
  393                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::One)
 
  403                         else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
 
  415                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
 
  417                             if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Two)
 
  429                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
 
  431                             if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
 
  444                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
 
  446                             if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
 
  459                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
 
  461                             if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
 
  474                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
 
  476                             if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Six)
 
  488                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
 
  490                             if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
 
  505                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::One)
 
  515                         else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
 
  527                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
 
  529                             if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Two)
 
  541                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
 
  543                             if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
 
  556                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
 
  558                             if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
 
  571                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
 
  573                             if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
 
  586                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
 
  588                             if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Six)
 
  600                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
 
  602                             if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
 
  621                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Odd)
 
  644                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Odd)
 
  670                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Odd)
 
  693                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Odd)
 
  750             return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
 
  762         if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
 
  778             if(
sizeof(CDataType) == 1)
 
  786             if constexpr(std::is_same_v<ADataType, ck::f8_t> ||
 
  787                          std::is_same_v<ADataType, ck::bf8_t>)
 
  831                              const BDataType* p_b,
 
  840                              AElementwiseOperation,
 
  841                              BElementwiseOperation,
 
  842                              CElementwiseOperation)
 
  844         return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
 
  860                                                       AElementwiseOperation,
 
  861                                                       BElementwiseOperation,
 
  862                                                       CElementwiseOperation)
 override 
  864         return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
 
  865                                           static_cast<const BDataType*
>(p_b),
 
  866                                           static_cast<CDataType*
>(p_c),
 
  879         return std::make_unique<Invoker>(
Invoker{});
 
  885         auto str = std::stringstream();
 
  887         std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
 
  891         std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
 
  904                 PrefetchStages = GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
 
  905                 AMmaKStride    = GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride;
 
  912                 PrefetchStages = GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages;
 
  913                 AMmaKStride    = GridwiseGemm32::BlockwiseGemmPipe::AMmaKStride;
 
  918         str << 
"DeviceGemmXdlUniversal" 
  921             << std::string(ALayout::name)[0]
 
  922             << std::string(BLayout::name)[0]
 
  923             << std::string(CLayout::name)[0]
 
  928             << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock << 
", " 
  930             << MPerXDL<<
"x"<<NPerXDL << 
", " 
  932             << MXdlPerWave<<
"x" << NXdlPerWave<<
", " 
  934             << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", " 
  935             << 
"BlkGemmPipelineScheduler: " 
  936             << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << 
", " 
  937             << 
"BlkGemmPipelineVersion: " 
  938             << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << 
", " 
  939             << 
"BlkGemmPipelinePrefetchStages: " 
  940             << PrefetchStages << 
", " 
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:114
 
#define REGISTER_EXTRA_PRINTING_METHODS
Definition: device_base.hpp:47
 
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
 
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
 
GemmSpecialization
Definition: gemm_specialization.hpp:11
 
void flush_icache()
Definition: flush_cache.hpp:383
 
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
 
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
 
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
 
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
 
constexpr bool is_same_v
Definition: type.hpp:283
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
 
int32_t index_t
Definition: ck.hpp:299
 
bool is_gfx11_supported()
Definition: device_prop.hpp:60
 
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:108
 
Definition: stream_config.hpp:10
 
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:716
 
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:247
 
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1202
 
Definition: data_type.hpp:187
 
Definition: device_base.hpp:197
 
Definition: device_base.hpp:208
 
Helper structure responsible for kernel invocation.
Definition: device_gemm_xdl_cshuffle_v3.hpp:263
 
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
This function issues GPU kernel execution.
Definition: device_gemm_xdl_cshuffle_v3.hpp:270
 
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:747
 
"Universal" GEMM operation with SplitK support.
Definition: device_gemm_xdl_cshuffle_v3.hpp:178
 
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_gemm_xdl_cshuffle_v3.hpp:181
 
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:820
 
bool GetPermuteA() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:827
 
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:877
 
index_t GetKPerBlock() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:825
 
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_cshuffle_v3.hpp:760
 
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl_cshuffle_v3.hpp:754
 
std::string GetTypeString() const override
Definition: device_gemm_xdl_cshuffle_v3.hpp:883
 
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl_cshuffle_v3.hpp:830
 
bool GetPermuteB() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:828
 
static auto MakeInvoker()
Definition: device_gemm_xdl_cshuffle_v3.hpp:847
 
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:850
 
static constexpr index_t BPackedSize
Definition: device_gemm_xdl_cshuffle_v3.hpp:246
 
typename GridwiseGemm64::Argument Argument
Definition: device_gemm_xdl_cshuffle_v3.hpp:237
 
static constexpr auto NXdlPerWave32
Definition: device_gemm_xdl_cshuffle_v3.hpp:182
 
static constexpr index_t APackedSize
Definition: device_gemm_xdl_cshuffle_v3.hpp:239
 
Definition: device_gemm_v2.hpp:22
 
Definition: flush_cache.hpp:299