21 namespace tensor_operation {
123 template <
typename ALayout,
129 typename GemmAccDataType,
130 typename CShuffleDataType,
131 typename AElementwiseOperation,
132 typename BElementwiseOperation,
133 typename CElementwiseOperation,
145 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
146 typename ABlockTransferThreadClusterArrangeOrder,
147 typename ABlockTransferSrcAccessOrder,
148 index_t ABlockTransferSrcVectorDim,
149 index_t ABlockTransferSrcScalarPerVector,
150 index_t ABlockTransferDstScalarPerVector_AK1,
151 bool ABlockLdsExtraM,
152 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
153 typename BBlockTransferThreadClusterArrangeOrder,
154 typename BBlockTransferSrcAccessOrder,
155 index_t BBlockTransferSrcVectorDim,
156 index_t BBlockTransferSrcScalarPerVector,
157 index_t BBlockTransferDstScalarPerVector_BK1,
158 bool BBlockLdsExtraN,
159 index_t CShuffleMXdlPerWavePerShuffle,
160 index_t CShuffleNXdlPerWavePerShuffle,
161 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
162 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
165 typename ComputeTypeA = CDataType,
166 typename ComputeTypeB = ComputeTypeA,
167 bool PermuteA =
false,
168 bool PermuteB =
false>
175 AElementwiseOperation,
176 BElementwiseOperation,
177 CElementwiseOperation>
184 template <index_t NXdlPerWave_>
194 AElementwiseOperation,
195 BElementwiseOperation,
196 CElementwiseOperation,
208 ABlockTransferThreadClusterLengths_AK0_M_AK1,
209 ABlockTransferThreadClusterArrangeOrder,
210 ABlockTransferSrcAccessOrder,
211 ABlockTransferSrcVectorDim,
212 ABlockTransferSrcScalarPerVector,
213 ABlockTransferDstScalarPerVector_AK1,
216 BBlockTransferThreadClusterLengths_BK0_N_BK1,
217 BBlockTransferThreadClusterArrangeOrder,
218 BBlockTransferSrcAccessOrder,
219 BBlockTransferSrcVectorDim,
220 BBlockTransferSrcScalarPerVector,
221 BBlockTransferDstScalarPerVector_BK1,
224 CShuffleMXdlPerWavePerShuffle,
225 CShuffleNXdlPerWavePerShuffle,
226 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
227 CShuffleBlockTransferScalarPerVector_NPerBlock,
269 template <
typename GridwiseGemm>
270 float RunImp(
const typename GridwiseGemm::Argument& arg,
273 if(stream_config.log_level_ > 0)
276 GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
279 if(!GridwiseGemm::CheckValidity(arg))
281 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
285 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
289 index_t k_grain = arg.KBatch * KPerBlock;
290 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
292 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
294 const auto Run = [&](
const auto& kernel) {
295 if(stream_config.flush_cache)
299 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
300 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
301 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
302 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
304 auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
306 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
310 arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
311 rotating_mem.Print();
313 auto run_flush_cache = [&]() {
320 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
322 arg_.M * arg_.N *
sizeof(CDataType),
323 stream_config.stream_id_));
326 ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
338 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
340 arg.M * arg.N *
sizeof(CDataType),
341 stream_config.stream_id_));
344 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
348 constexpr
index_t minimum_occupancy = []() {
355 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
363 if(has_main_k_block_loop)
393 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
403 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
415 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
417 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
429 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
431 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
444 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
446 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
459 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
461 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
474 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
476 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
488 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
490 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
505 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
515 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
527 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
529 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
541 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
543 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
556 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
558 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
571 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
573 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
586 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
588 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
600 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
602 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
621 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
644 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
670 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
693 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
750 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
762 if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
778 if(
sizeof(CDataType) == 1)
786 if constexpr(std::is_same_v<ADataType, ck::f8_t> ||
787 std::is_same_v<ADataType, ck::bf8_t>)
831 const BDataType* p_b,
840 AElementwiseOperation,
841 BElementwiseOperation,
842 CElementwiseOperation)
844 return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
860 AElementwiseOperation,
861 BElementwiseOperation,
862 CElementwiseOperation)
override
864 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
865 static_cast<const BDataType*
>(p_b),
866 static_cast<CDataType*
>(p_c),
879 return std::make_unique<Invoker>(
Invoker{});
885 auto str = std::stringstream();
887 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
891 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
904 PrefetchStages = GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
905 AMmaKStride = GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride;
912 PrefetchStages = GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages;
913 AMmaKStride = GridwiseGemm32::BlockwiseGemmPipe::AMmaKStride;
918 str <<
"DeviceGemmXdlUniversal"
921 << std::string(ALayout::name)[0]
922 << std::string(BLayout::name)[0]
923 << std::string(CLayout::name)[0]
928 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
930 << MPerXDL<<
"x"<<NPerXDL <<
", "
932 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
934 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
935 <<
"BlkGemmPipelineScheduler: "
936 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
937 <<
"BlkGemmPipelineVersion: "
938 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
939 <<
"BlkGemmPipelinePrefetchStages: "
940 << PrefetchStages <<
", "
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:114
#define REGISTER_EXTRA_PRINTING_METHODS
Definition: device_base.hpp:47
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:361
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
constexpr bool is_same_v
Definition: type.hpp:283
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
int32_t index_t
Definition: ck.hpp:299
bool is_gfx11_supported()
Definition: device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:108
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:716
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:247
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1202
Definition: data_type.hpp:187
Definition: device_base.hpp:197
Definition: device_base.hpp:208
Helper structure responsible for kernel invocation.
Definition: device_gemm_xdl_cshuffle_v3.hpp:263
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
This function issues GPU kernel execution.
Definition: device_gemm_xdl_cshuffle_v3.hpp:270
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:747
"Universal" GEMM operation with SplitK support.
Definition: device_gemm_xdl_cshuffle_v3.hpp:178
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_gemm_xdl_cshuffle_v3.hpp:181
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:820
bool GetPermuteA() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:827
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:877
index_t GetKPerBlock() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:825
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_cshuffle_v3.hpp:760
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl_cshuffle_v3.hpp:754
std::string GetTypeString() const override
Definition: device_gemm_xdl_cshuffle_v3.hpp:883
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl_cshuffle_v3.hpp:830
bool GetPermuteB() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:828
static auto MakeInvoker()
Definition: device_gemm_xdl_cshuffle_v3.hpp:847
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:850
static constexpr index_t BPackedSize
Definition: device_gemm_xdl_cshuffle_v3.hpp:246
typename GridwiseGemm64::Argument Argument
Definition: device_gemm_xdl_cshuffle_v3.hpp:237
static constexpr auto NXdlPerWave32
Definition: device_gemm_xdl_cshuffle_v3.hpp:182
static constexpr index_t APackedSize
Definition: device_gemm_xdl_cshuffle_v3.hpp:239
Definition: device_gemm_v2.hpp:22
Definition: flush_cache.hpp:283