namespace tensor_operation {
// ...

template <typename ALayout,
          // ...
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          // ...
          typename ReduceDataType = CDataType,
          typename ComputeTypeA   = CDataType,
          typename ComputeTypeB   = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffleV3R1
    : public DeviceGemmV2R1<ALayout, // ... (leading base arguments abridged)
                            AElementwiseOperation,
                            BElementwiseOperation,
                            CElementwiseOperation>
{
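    // Why ReduceDataType and the workspace exist: with KBatch > 1, each split
    // writes an M x N slab of partial sums into a workspace, and a second
    // kernel adds the slabs into C. A host-side reference of that dataflow
    // (an illustrative sketch only, not part of the original header; requires
    // <vector> and <cstddef>):
    static void SplitKReferenceSketch(const float* a, const float* b, float* c,
                                      int M, int N, int K, int kbatch)
    {
        const int k_per = K / kbatch; // assume K divides evenly in this sketch
        std::vector<float> ws(static_cast<std::size_t>(kbatch) * M * N, 0.f);
        for(int kb = 0; kb < kbatch; ++kb) // one partial GEMM per split
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    for(int k = kb * k_per; k < (kb + 1) * k_per; ++k)
                        ws[(static_cast<std::size_t>(kb) * M + m) * N + n] +=
                            a[m * K + k] * b[k * N + n];
        for(int m = 0; m < M; ++m) // reduce over the split dimension
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int kb = 0; kb < kbatch; ++kb)
                    acc += ws[(static_cast<std::size_t>(kb) * M + m) * N + n];
                c[m * N + n] = acc; // cf. GetWorkSpaceSize: M * N * KBatch elements
            }
    }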
 
    // Grid-wise kernel instantiation, parameterized on the per-wave XDL count
    // (alias name assumed for this excerpt; leading and trailing template
    // arguments abridged):
    template <index_t NXdlPerWave_>
    using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3<
        // ...
        AElementwiseOperation,
        BElementwiseOperation,
        // ...
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        // ...
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        // ...
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        // ...
        >;
 
        // struct Argument -- full constructor signature; member/base
        // initialization abridged:
        Argument(const ADataType* p_a_grid_,
                 const BDataType* p_b_grid_,
                 const std::array<const void*, NumDTensor> p_ds_,
                 CDataType* p_c_grid_,
                 index_t M_,
                 index_t N_,
                 index_t K_,
                 index_t StrideA_,
                 index_t StrideB_,
                 std::array<ck::index_t, NumDTensor> StrideDs_,
                 index_t StrideC_,
                 index_t k_batch_)
            : // ... (base-argument initialization elided; the C pointer is
              // forwarded re-interpreted as the reduction element type)
              reinterpret_cast<ReduceDataType*>(p_c_grid_),
              // ...

        const std::array<const void*, NumDTensor> p_ds;
        std::array<ck::index_t, NumDTensor> StrideDs;
 
    using ReduceAdd               = ck::reduce::Add;
    using OutElementwiseOperation = CElementwiseOperation;
    // ... (DsVectorLengthSequence definition elided)

    using DeviceReduceInstance =
        DeviceReduceThreadWiseMultiD<ReduceDataType, // input: split-K partials
                                     DsDataType,
                                     GemmAccDataType,
                                     CDataType, // output: final C
                                     3,         // input rank: (KBatch, M, N)
                                     1,         // one reduced dimension (KBatch)
                                     ReduceAdd,
                                     PassThrough,
                                     OutElementwiseOperation,
                                     256,
                                     CShuffleBlockTransferScalarPerVector_NPerBlock,
                                     1,
                                     0,
                                     CShuffleBlockTransferScalarPerVector_NPerBlock,
                                     CShuffleBlockTransferScalarPerVector_NPerBlock,
                                     decltype(DsVectorLengthSequence)>;
 
    struct Invoker : public BaseInvoker
    {
        float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            static constexpr index_t NumInDim  = 3;
            static constexpr index_t NumOutDim = 2;

            // The workspace is viewed as a (KBatch, M, N) tensor and reduced
            // over dimension 0 into the (M, N) output.
            std::array<ck::index_t, NumInDim> in_lengths   = {arg.KBatch, arg.M, arg.N};
            std::array<ck::index_t, NumOutDim> out_lengths = {arg.M, arg.N};

            std::array<ck::index_t, NumInDim> in_strides;
            std::array<ck::index_t, NumOutDim> out_strides;
            if(/* C is row-major; condition elided */)
            {
                in_strides  = {arg.M * arg.N, arg.N, 1};
                out_strides = {arg.N, 1};
            }
            else
            {
                in_strides  = {arg.M * arg.N, 1, arg.M};
                out_strides = {1, arg.M};
            }

            std::array<int, 1> reduce_dims{0};

            std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths;
            std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides;

            static_for<0, NumDTensor, 1>{}([&](auto i) {
                DsLengths[i] = out_lengths;
                if(/* D_i is row-major; condition elided */)
                    DsStrides[i] = {arg.StrideDs[i], 1};
                else
                    DsStrides[i] = {1, arg.StrideDs[i]};
            });

            auto reduce       = DeviceReduceInstance{};
            auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
                                                           // ... (strides, Ds
                                                           // descriptors, reduce_dims
                                                           // and pointers elided)
            auto invoker_ptr  = reduce.MakeInvokerPointer();

            float ave_time = 0;
            if(reduce.IsSupportedArgument(argument_ptr.get()))
            {
                ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
            }
            else
            {
                throw std::runtime_error(
                    "The runtime parameters seem not to be supported by the device instance, exiting!");
            }
            return ave_time;
        }
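        // Sanity check for the stride tables above (an illustrative helper,
        // not in the original header): with row-major C, workspace element
        // (kb, m, n) lands at kb*M*N + m*N + n; with column-major C it lands
        // at kb*M*N + n*M + m. Both views cover the same M*N*KBatch buffer.
        static ck::index_t FlatIndexSketch(const std::array<ck::index_t, 3>& idx,
                                           const std::array<ck::index_t, 3>& strides)
        {
            return idx[0] * strides[0] + idx[1] * strides[1] + idx[2] * strides[2];
        }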
 
        template <typename GridwiseGemm>
        float RunImp(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{})
        {
            auto arg = *reinterpret_cast<const typename GridwiseGemm::Argument*>(&arg_);

            // ... (workspace check elided)
                    throw std::runtime_error("using reduce, but empty workspace!");
            // ...

            if(stream_config.log_level_ > 0)
            {
                // ... (argument printout elided)
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);

            // ... (k_grain definition elided)
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
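            // Worked example of the rounding above, assuming (as the elided
            // definition suggests) k_grain = KBatch * KPerBlock:
            //   K = 1000, KPerBlock = 64, KBatch = 4  =>  k_grain = 256
            //   K_split = (1000 + 255) / 256 * 64 = 4 * 64 = 256
            // i.e. the K extent one split iterates over, rounded up to whole
            // KPerBlock tiles; the tail handling below depends on this value.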
 
            float ave_time = 0;

            const auto Run = [&](const auto& kernel) {
                if(stream_config.flush_cache)
                {
                    // ... (rotating-buffer construction elided)
                        stream_config.rotating_count,
                        arg.M * arg.K * sizeof(ADataType),
                        arg.K * arg.N * sizeof(BDataType));
                    rotating_mem.Print();

                    auto run_flush_cache = [&]() {
                        // ... (icache flush and argument rotation elided)
                    };

                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                        // ...
                }
                else
                {
                    // ... (plain timed launch)
                        stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
                }
            };

            constexpr index_t minimum_occupancy =
                // ... (value depends on the pipeline scheduler)
 
            if(has_main_k_block_loop)
            {
                // ... (scheduler/pipeline-version dispatch elided; each branch
                // launches the kernel instantiation matching the loop tail)
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                    {
                        // ...
                    }
                    else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
                    {
                        // ...
                    }
                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                        {
                            // ...
                        }
                    }
                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
                    {
                        // ... (TailNumber::Three, and likewise below)
                    }
                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
                    {
                        // ...
                    }
                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
                    {
                        // ...
                    }
                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                        {
                            // ...
                        }
                    }
                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
                    {
                        // ...
                    }
                // ... (other pipeline variants elided)
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        // ...
                    }
            }
            // ... (path without a main K block loop)
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        // ...
                    }
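                // Condensed shape of the dispatch above (pseudo-structure for
                // orientation, not the actual launch code): one kernel
                // instantiation exists per reachable tail length, and the host
                // picks the one matching the runtime K extent.
                //
                //   switch(GridwiseGemm::CalculateKBlockLoopTailNum(K_split))
                //   {
                //   case TailNumber::One:  /* launch <TailNumber::One>  */ break;
                //   case TailNumber::Full: /* launch <TailNumber::Full> */ break;
                //   // Two .. Seven only when PrefetchStages makes them
                //   // reachable (the if constexpr guards above)
                //   case TailNumber::Odd:  /* launch <TailNumber::Odd>  */ break;
                //   }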
 
            // ... (guard elided)
                ave_time += RunReduce(arg_, stream_config);

            return ave_time;
        }

        INVOKER_RUN_IMPL float Run(const BaseArgument* p_arg,
                                   const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
 
    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
        {
            return false;
        }
        // ... (wave64 path)
            if constexpr(NXdlPerWave64 > 0)
            {
                return GridwiseGemm64::CheckValidity(arg);
            }
        // ... (wave32 path)
            if constexpr(NXdlPerWave32 > 0)
            {
                return GridwiseGemm32::CheckValidity(
                    // ...
            }
        // ...
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }
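    // NXdlPerWave64 / NXdlPerWave32 (declared via GET_NXDL_PER_WAVE_IMPL in
    // device_base.hpp) pick the GridwiseGemm instantiation for wave64 vs.
    // wave32 devices; the `if constexpr(... > 0)` guards above mean validity
    // is only ever checked against a variant that is actually instantiable
    // for this configuration.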
 
    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             const std::array<const void*, NumDTensor> p_ds,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             std::array<ck::index_t, NumDTensor> StrideDs,
                             index_t StrideC,
                             index_t KBatch,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation)
    {
        return Argument{p_a, p_b, p_ds, p_c, M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch};
    }
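    // Typical host-side use of the factories (an illustrative sketch; the
    // DeviceOp alias, pointer variables and sizes are placeholders, not part
    // of this header):
    //
    //   using DeviceOp = /* a full specialization of this class */;
    //   auto op       = DeviceOp{};
    //   auto argument = DeviceOp::MakeArgument(p_a, p_b, {}, p_c, M, N, K,
    //                                          StrideA, StrideB, {}, StrideC,
    //                                          /*KBatch=*/4, {}, {}, {});
    //   auto invoker  = DeviceOp::MakeInvoker();
    //   if(op.IsSupportedArgument(argument))
    //       invoker.Run(argument, StreamConfig{});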
 
    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      std::array<const void*, NumDTensor> p_ds,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      std::array<ck::index_t, NumDTensor> StrideDs,
                                                      index_t StrideC,
                                                      index_t KBatch,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          p_ds,
                                          static_cast<CDataType*>(p_c),
                                          M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
 
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            // ...
        };
        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            // ...
        };

        str << "DeviceGemmXdlUniversalReduce"
            // ...
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            // ...
            << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
            // ...
            << MPerXDL << "x" << NPerXDL << ", "
            // ...
            << MXdlPerWave << "x" << NXdlPerWave << ", "
            // ...
            << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;

        return str.str();
    }
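    // The string assembled above reads, in order: kernel name, one letter per
    // layout (first character of {A,B,C}Layout::name), the block tile
    // MPerBlock x NPerBlock x KPerBlock, the XDL tile MPerXDL x NPerXDL, the
    // per-wave XDL counts, the A/B source vector widths, then the block-GEMM
    // pipeline scheduler, version and prefetch stage count.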
 
    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
    {
        auto arg = *dynamic_cast<const Argument*>(p_arg);
        // ... (condition elided: a workspace is only needed when the split-K
        // reduction path is active)
        {
            std::cout << "using workspace" << std::endl;
            return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType);
        }
        // ...
    }
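    // The caller owns this workspace. A minimal allocation sketch, assuming
    // HIP and the SetWorkSpacePointer helper of the BaseOperator interface
    // (device_base.hpp); error handling omitted:
    //
    //   const std::size_t ws_bytes = op.GetWorkSpaceSize(&argument);
    //   void* p_workspace          = nullptr;
    //   hipMalloc(&p_workspace, ws_bytes); // M * N * KBatch partial results
    //   op.SetWorkSpacePointer(&argument, p_workspace);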
 