namespace tensor_operation {

// ...

template <typename ALayout,
          // ...
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          // ...
          typename ReduceDataType = CDataType,
          typename ComputeTypeA   = CDataType,
          typename ComputeTypeB   = ComputeTypeA>
// ... (struct declaration elided; its base-class template argument list ends with:)
          AElementwiseOperation,
          BElementwiseOperation,
          CElementwiseOperation>
    template <index_t NXdlPerWave_>
    // ...
        AElementwiseOperation,
        BElementwiseOperation,
        // ...
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        // ...
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        // ...
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
    // ...
    Argument(const ADataType* p_a_grid_,
             const BDataType* p_b_grid_,
             const std::array<const void*, NumDTensor> p_ds_,
             CDataType* p_c_grid_,
             index_t M_,
             index_t N_,
             index_t K_,
             index_t StrideA_,
             index_t StrideB_,
             std::array<ck::index_t, NumDTensor> StrideDs_,
             index_t StrideC_,
             index_t k_batch_)
        : // ...
          reinterpret_cast<ReduceDataType*>(p_c_grid_),
          // ...

    // ...
    const std::array<const void*, NumDTensor> p_ds;
    // ...
    using ReduceAdd               = ck::reduce::Add;
    using OutElementwiseOperation = CElementwiseOperation;
    // ...
    using DeviceReduceInstance =
        DeviceReduceThreadWiseMultiD<ReduceDataType,
                                     DsDataType,
                                     GemmAccDataType,
                                     CDataType,
                                     3, // rank of the workspace view: (KBatch, M, N)
                                     1, // number of reduced dimensions
                                     ReduceAdd,
                                     PassThrough,
                                     OutElementwiseOperation,
                                     256,
                                     CShuffleBlockTransferScalarPerVector_NPerBlock,
                                     1,
                                     0,
                                     CShuffleBlockTransferScalarPerVector_NPerBlock,
                                     CShuffleBlockTransferScalarPerVector_NPerBlock,
                                     decltype(DsVectorLengthSequence)>;
        static constexpr index_t NumInDim  = 3;
        static constexpr index_t NumOutDim = 2;

        // The split-K workspace is viewed as a (KBatch, M, N) tensor, reduced to (M, N).
        std::array<ck::index_t, NumInDim> in_lengths   = {arg.KBatch, arg.M, arg.N};
        std::array<ck::index_t, NumOutDim> out_lengths = {arg.M, arg.N};

        std::array<ck::index_t, NumInDim> in_strides;
        std::array<ck::index_t, NumOutDim> out_strides;

        // ... (row-major C)
        in_strides  = {arg.M * arg.N, arg.N, 1};
        out_strides = {arg.N, 1};
        // ... (column-major C)
        in_strides  = {arg.M * arg.N, 1, arg.M};
        out_strides = {1, arg.M};
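// A minimal standalone sketch of the layout those strides describe (illustrative,
// not part of this header): the workspace holds KBatch partial C matrices of shape
// M x N, and partial element (kb, m, n) lives at kb*M*N + m*N + n for a row-major C
// (the {M*N, N, 1} triple above), or kb*M*N + n*M + m for column-major ({M*N, 1, M}).
#include <array>
#include <cstdint>

using index_t = std::int32_t;

// Offset of partial-result element (kb, m, n) in the KBatch x M x N workspace.
constexpr index_t workspace_offset(const std::array<index_t, 3>& s,
                                   index_t kb, index_t m, index_t n)
{
    return kb * s[0] + m * s[1] + n * s[2];
}

// Row-major C with M = 4, N = 8: in_strides = {M * N, N, 1} = {32, 8, 1},
// so (kb = 1, m = 2, n = 3) -> 32 + 16 + 3 = 51.
static_assert(workspace_offset({32, 8, 1}, 1, 2, 3) == 51, "");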
        // Reduce over dimension 0, i.e. sum the KBatch partial results.
        std::array<int, 1> reduce_dims{0};

        std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths;
        std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides;

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            DsLengths[i] = out_lengths;
            // ... (row-major D)
            DsStrides[i] = {arg.StrideDs[i], 1};
            // ... (column-major D)
            DsStrides[i] = {1, arg.StrideDs[i]};
        auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
                                                       // ...
        // ...
        auto invoker_ptr = reduce.MakeInvokerPointer();
        // ...
        if(reduce.IsSupportedArgument(argument_ptr.get()))
        {
            ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
        }
        // ...
        throw std::runtime_error(
            "The runtime parameters seem not to be supported by the device instance, exiting!");
    template <typename GridwiseGemm>
    float RunImp(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{})
    {
        auto arg = *reinterpret_cast<const typename GridwiseGemm::Argument*>(&arg_);
        // ...
            throw std::runtime_error("using reduce, but empty workspace!");
        if(stream_config.log_level_ > 0)
        // ...
        if(!GridwiseGemm::CheckValidity(arg))
        {
            throw std::runtime_error("wrong! GridwiseGemm has an invalid setting");
        }
        // ...
        std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
        // ...
        index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
        // ...
        const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
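// A worked sketch of the K_split arithmetic (illustrative; the definition of
// k_grain is elided from this listing, and k_grain = KBatch * KPerBlock below is
// an assumption based on the related GridwiseGemm v3 kernels):
#include <cstdint>

constexpr std::int32_t calculate_k_split(std::int32_t K, std::int32_t KBatch,
                                         std::int32_t KPerBlock)
{
    const std::int32_t k_grain = KBatch * KPerBlock; // assumed definition
    // Round K up to a whole number of k_grain tiles, then express the result
    // as a per-split K that is a multiple of KPerBlock.
    return (K + k_grain - 1) / k_grain * KPerBlock;
}

// e.g. K = 1000, KBatch = 2, KPerBlock = 64: k_grain = 128,
// ceil(1000 / 128) = 8 tiles, so K_split = 8 * 64 = 512 per split.
static_assert(calculate_k_split(1000, 2, 64) == 512, "");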
        const auto Run = [&](const auto& kernel) {
            if(stream_config.flush_cache)
            {
                // ... (construct rotating copies of A and B:)
                    stream_config.rotating_count,
                    arg.M * arg.K * sizeof(ADataType),
                    arg.K * arg.N * sizeof(BDataType));
                rotating_mem.Print();
                // ...
                auto run_flush_cache = [&]() {
                    // ...
                ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                    // ...
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
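// A minimal sketch of the rotating-buffer idea behind the flush_cache path
// (illustrative; CK's actual helper and most of its construction are elided
// above): keep rotating_count copies of the A (M*K elements) and B (K*N
// elements) buffers and cycle through them between timed launches, so each
// repetition reads data that is cold in cache instead of re-hitting lines
// left warm by the previous launch.
#include <cstddef>

template <typename LaunchFn> // LaunchFn: float(std::size_t buffer_set_index)
float average_time_with_rotation(std::size_t rotating_count, LaunchFn&& launch_once)
{
    float total_ms = 0.f;
    for(std::size_t i = 0; i < rotating_count; ++i)
        total_ms += launch_once(i); // each iteration targets a different A/B copy
    return total_ms / static_cast<float>(rotating_count);
}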
            constexpr index_t minimum_occupancy = // ...
        // ...
        if(has_main_k_block_loop)
        {
            // ...
            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
            // ...
            else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
            // ...
            if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
            {
                if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                // ...
            }
            if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
            // ...
            if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
            // ...
            if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
            // ...
            if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
            {
                if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                // ...
            }
            if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
            // ...
            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
            // ...
        }
        // ... (no main K-block loop)
        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
        // ...
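// A compact sketch of the dispatch pattern above (illustrative; the local enum
// and the comment placeholders stand in for the real kernel instantiations):
// the runtime tail count selects a kernel compiled for that tail, while each
// branch is guarded by `if constexpr` on PrefetchStages so kernels for tail
// values the pipeline can never produce are not instantiated at all, bounding
// compile time and binary size.
enum class TailNum { One, Two, Three, Four, Five, Six, Seven, Full, Odd, Even };

template <int PrefetchStages>
void dispatch_by_tail(TailNum tail)
{
    if(tail == TailNum::One) { /* launch kernel compiled for TailNum::One */ }
    else if(tail == TailNum::Full) { /* launch kernel compiled for TailNum::Full */ }

    if constexpr(PrefetchStages > 2)
        if(tail == TailNum::Two) { /* only instantiated when reachable */ }
    if constexpr(PrefetchStages > 3)
        if(tail == TailNum::Three) { /* ... */ }
    // ... and so on, up to PrefetchStages > 7 / TailNum::Seven, with the
    // Odd/Even tails handled for the remaining pipeline variants.
}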
        // ...
        ave_time += RunReduce(arg_, stream_config);
        // ...

    INVOKER_RUN_IMPL
    float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override
    {
        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
    static bool IsSupportedArgument(const Argument& arg)
    {
        // ...
        if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
        // ...
        if constexpr(NXdlPerWave64 > 0)
        {
            return GridwiseGemm64::CheckValidity(arg);
        }
        // ...
        if constexpr(NXdlPerWave32 > 0)
        {
            return GridwiseGemm32::CheckValidity(
        // ...

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             const std::array<const void*, NumDTensor> p_ds,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             std::array<ck::index_t, NumDTensor> StrideDs,
                             index_t StrideC,
                             index_t KBatch,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation)
    {
        return Argument{p_a, p_b, p_ds, p_c, M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch};
    }
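// A minimal usage sketch under assumed template arguments (the DeviceOp alias,
// the pointers, and the sizes below are placeholders, not part of this header):
// build an Argument, attach the split-K workspace, check support, then run
// through an Invoker.
//
//   using DeviceOp = DeviceGemmXdlUniversalReduce</* ...template arguments... */>;
//
//   auto gemm = DeviceOp{};
//   auto arg  = DeviceOp::MakeArgument(p_a, p_b, p_ds, p_c,
//                                      M, N, K,
//                                      StrideA, StrideB, StrideDs, StrideC,
//                                      KBatch,
//                                      AElementwiseOperation{},
//                                      BElementwiseOperation{},
//                                      CElementwiseOperation{});
//   gemm.SetWorkSpacePointer(&arg, p_workspace); // partial results live here
//   if(gemm.IsSupportedArgument(arg))
//   {
//       auto invoker = DeviceOp::MakeInvoker();
//       float ms     = invoker.Run(arg, StreamConfig{});
//   }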
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      std::array<const void*, NumDTensor> p_ds,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      std::array<ck::index_t, NumDTensor> StrideDs,
                                                      index_t StrideC,
                                                      index_t KBatch,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          p_ds,
                                          static_cast<CDataType*>(p_c),
                                          M, N, K,
                                          StrideA, StrideB, StrideDs, StrideC,
                                          KBatch);
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
        // ...
        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            // ...
        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            // ...
        str << "DeviceGemmXdlUniversalReduce"
            // ...
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            // ...
            << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
            // ...
            << MPerXDL << "x" << NPerXDL << ", "
            // ...
            << MXdlPerWave << "x" << NXdlPerWave << ", "
            // ...
            << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
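// For orientation only, the assembled name looks roughly like the following
// (placeholder values; the separators elided from this listing may differ):
//
//   DeviceGemmXdlUniversalReduce RRR ... 256x128x64, 32x32, 2x2, 8x8,
//   BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1,
//   BlkGemmPipelinePrefetchStages: 2
//
// The single characters after the prefix are the first letters of the A, B and
// C layout names (e.g. 'R' for RowMajor, 'C' for ColumnMajor).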
    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
    {
        auto arg = *dynamic_cast<const Argument*>(p_arg);
        // ...
        std::cout << "using workspace" << std::endl;
        return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType);
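// A worked example of that size formula (illustrative values): the workspace
// stores one ReduceDataType partial C matrix per K-batch, so for M = N = 4096,
// KBatch = 4 and ReduceDataType = float it needs 256 MiB.
#include <cstddef>

constexpr std::size_t workspace_bytes(std::size_t M, std::size_t N,
                                      std::size_t KBatch, std::size_t elem_size)
{
    return M * N * KBatch * elem_size; // one partial C per K-batch
}

// 4096 * 4096 * 4 * sizeof(float) = 268435456 bytes = 256 MiB.
static_assert(workspace_bytes(4096, 4096, 4, sizeof(float)) == (256ull << 20),
              "4096x4096, KBatch = 4, float partials -> 256 MiB");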