template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
#if CK_USE_WAVES_PER_EU
    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
        kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
{
#if(defined(__gfx103__) || defined(__gfx11__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
        GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA));
    const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane(
        GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB));
    const auto c_grid_desc_m_n = amd_wave_read_first_lane(
        GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC));
    GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
                                                  karg.p_b_grid,
                                                  karg.p_c_grid,
                                                  p_shared,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  c_grid_desc_m_n);
#else
    ignore = karg;
#endif
}
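// Illustrative only (not part of the header): a minimal host-side dispatch
// sketch, assuming HIP and a concrete GridwiseGemm instantiation. The
// LaunchGemmDpp name and the grid_size/block_size parameters are hypothetical;
// kernel_gemm_dpp and CalculateHasMainKBlockLoop come from this header.
//
// template <typename GridwiseGemm>
// void LaunchGemmDpp(const typename GridwiseGemm::Argument& karg,
//                    index_t grid_size, index_t block_size, hipStream_t stream)
// {
//     if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
//         kernel_gemm_dpp<GridwiseGemm, true><<<grid_size, block_size, 0, stream>>>(karg);
//     else
//         kernel_gemm_dpp<GridwiseGemm, false><<<grid_size, block_size, 0, stream>>>(karg);
// }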
template <
    /* ... */
    typename AElementwiseOperation,
    typename BElementwiseOperation,
    typename CElementwiseOperation,
    /* ... */
    typename ABlockTransferThreadClusterLengths_K0_M_K1,
    typename ABlockTransferThreadClusterArrangeOrder,
    typename ABlockTransferSrcAccessOrder,
    index_t ABlockTransferSrcVectorDim,
    index_t ABlockTransferSrcScalarPerVector,
    index_t ABlockTransferDstScalarPerVector_K1,
    bool AThreadTransferSrcResetCoordinateAfterRun,
    /* ... */
    typename BBlockTransferThreadClusterLengths_K0_N_K1,
    typename BBlockTransferThreadClusterArrangeOrder,
    typename BBlockTransferSrcAccessOrder,
    index_t BBlockTransferSrcVectorDim,
    index_t BBlockTransferSrcScalarPerVector,
    index_t BBlockTransferDstScalarPerVector_K1,
    bool BThreadTransferSrcResetCoordinateAfterRun,
    /* ... */
    typename CThreadTransferSrcDstAccessOrder,
    index_t CThreadTransferSrcDstVectorDim,
    index_t CThreadTransferDstScalarPerVector,
    index_t NumGemmKPrefetchStage = 1,
    /* ... */>
struct /* gridwise GEMM (DPP); struct name elided in this excerpt */
{
    struct Problem
    {
        // ... (constructor and members: M, N, K, StrideA, StrideB, StrideC,
        //      MPadded, NPadded, AK0, BK0)

        __host__ void Print() const
        {
            std::cout << "problem {"
                      << "M:" << M << ", "
                      << "N:" << N << ", "
                      << "K:" << K << ", "
                      << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << "}" << std::endl;
        }
    };

    struct Argument : public Problem
    {
        __host__ Argument(const ABDataType* p_a_grid_,
                          const ABDataType* p_b_grid_,
                          CDataType* p_c_grid_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
                          index_t StrideC_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_c_grid{p_c_grid_}
        {
        }

        const ABDataType* p_a_grid;
        const ABDataType* p_b_grid;
        CDataType* p_c_grid;
    };
    using GridwiseGemmPipe = remove_cvref_t<
        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A tile in LDS, destination of the A blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                // ... descriptor with extra padding along M ...
            }
            // ...
        }();

        return a_block_desc_ak0_m_ak1;
    }
    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B tile in LDS, destination of the B blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                // ... descriptor with extra padding along N ...
            }
            // ...
        }();

        return b_block_desc_bk0_n_bk1;
    }
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType);
    }
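    // Back-of-the-envelope check (illustrative tuning values, not defaults of
    // this header): with AK0PerBlock = BK0PerBlock = 16, MPerBlock = NPerBlock
    // = 128 and AK1 = BK1 = 4, each tile holds 16 * 128 * 4 = 8192 elements;
    // for fp16 data that is (8192 + 8192) * 2 bytes = 32 KiB, which fits the
    // 64 KiB of LDS available to a workgroup on gfx103x/gfx11.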
259 "Wrong! AK1 must be known at the time of compilation.");
261 "Wrong! BK1 must be known at the time of compilation.");
264 MPerBlock % (MPerDpp * MDppPerWave) == 0,
265 "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave.");
267 NPerBlock % (NPerDpp * NDppPerWave) == 0,
268 "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave.");
271 KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
272 "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1.");
274 static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0,
275 "Invalid tuning parameters! AK1Value must be divisible by "
276 "ABlockTransferDstScalarPerVector_K1");
278 static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0,
279 "Invalid tuning parameters! BK1Value must be divisible by "
280 "BBlockTransferDstScalarPerVector_K1");
        // (the checks below are shown without their surrounding context; lines
        //  between them, e.g. layout-dependent guards, are elided)
        if(!(problem.M % MPerBlock == 0))
        {
            return false;
        }

        if(!(problem.N % NPerBlock == 0))
        {
            return false;
        }

        if(problem.K % ABlockTransferSrcScalarPerVector != 0)
        {
            return false;
        }

        if(problem.M % ABlockTransferSrcScalarPerVector != 0)
        {
            return false;
        }

        if(problem.N % BBlockTransferSrcScalarPerVector != 0)
        {
            return false;
        }

        if(problem.K % BBlockTransferSrcScalarPerVector != 0)
        {
            return false;
        }

        if(problem.K % KPerBlock != 0)
        {
            return false;
        }

        const auto num_k_loop = problem.K / KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            return false;
        }

        return true;
    }
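    // Illustrative host-side guard, assuming `Gemm` is a concrete instantiation
    // of this struct (the variable names are hypothetical):
    //
    //   const auto problem = Gemm::Problem{M, N, K, StrideA, StrideB, StrideC};
    //   if(!Gemm::CheckValidity(problem))
    //       throw std::runtime_error("shape unsupported by this tuning config");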
    static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
    {
        const auto num_loop = K / KPerBlock;

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n)
    {
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        using BlockwiseGemm =
            BlockwiseGemmDpp< // type name abbreviated; see blockwise_gemm_dpp.hpp
                /* ... */
                decltype(a_block_desc_ak0_m_ak1),
                decltype(b_block_desc_bk0_n_bk1),
                /* ... */>;

        return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);
    }
    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
    static __device__ auto
    MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            // ... naive M x K descriptor; strides depend on the A layout ...
        }();

        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        // ... transform into the (AK0, M, AK1) view ...
    }
    static __device__ auto
    MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            // ... naive N x K descriptor; strides depend on the B layout ...
        }();

        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        // ... transform into the (BK0, N, BK1) view ...
    }
    static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            // ... naive M x N descriptor; strides depend on the C layout ...
        }();

        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
    }
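    // Numerically (illustrative): with MPerBlock = 128 and a raw M = 1000, the
    // padder extends the descriptor to ceil(1000 / 128) * 128 = 1024, so the
    // grid always covers whole tiles; the same rounding applies to N and, via
    // AK0/BK0, to K.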
    template <bool HasMainKBlockLoop,
              typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
              typename CGridDesc_M_N>
    __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
                               const ABDataType* __restrict__ p_b_grid,
                               CDataType* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                               const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 =
            MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize());
        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};
        // map this workgroup to a C tile
        const auto block_2_ctile_map = /* ... block-to-C-tile map ... */;

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0),
                          c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1))))
        {
            return;
        }
        // readfirstlane keeps the tile origin uniform across the wave (in SGPRs)
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
        // LDS tile descriptors (shared between the copies and the blockwise GEMM)
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy: global -> LDS
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                /* ... */
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                /* ... */
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                /* ... */
                                                ABlockTransferSrcVectorDim,
                                                /* ... */
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                /* ... */
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */
                                                NumGemmKPrefetchStage>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                /* ... */);
        // B matrix blockwise copy: global -> LDS
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                /* ... */
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                /* ... */
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                /* ... */
                                                BBlockTransferSrcVectorDim,
                                                /* ... */
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                /* ... */
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */
                                                NumGemmKPrefetchStage>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                /* ... */);
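        // Conceptual sketch (not CK's implementation) of what each blockwise
        // copy above does: every thread in the workgroup strides over the tile
        // and moves elements from global memory into LDS. The real transfer is
        // vectorized (ScalarPerVector elements at a time) and follows the
        // configured thread-cluster lengths and access order instead of this
        // flat loop:
        //
        //   for(int i = threadIdx.x; i < elems_per_tile; i += blockDim.x)
        //       tile_lds[i] = tile_gmem[i];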
        // advance the copies by one K-block per pipeline iteration
        constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0);

        // blockwise GEMM: LDS tiles -> per-thread C accumulators
        auto blockwise_gemm =
            BlockwiseGemmDpp< // type name abbreviated; see blockwise_gemm_dpp.hpp
                /* ... */
                decltype(a_block_desc_ak0_m_ak1),
                decltype(b_block_desc_bk0_n_bk1),
                /* ... */>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
        // carve the A and B tiles out of the shared-memory arena; B starts
        // right after the aligned A region
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0);

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock);
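        // With illustrative numbers: K = 4096 and AK1 = 4 give AK0 = 1024;
        // AK0PerBlock = 16 then yields 1024 / 16 = 64 K-block iterations of
        // the pipeline below.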
        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                                          a_block_desc_ak0_m_ak1,
                                                          a_blockwise_copy,
                                                          a_grid_buf,
                                                          a_block_buf,
                                                          a_block_slice_copy_step,
                                                          b_grid_desc_bk0_n_bk1,
                                                          b_block_desc_bk0_n_bk1,
                                                          b_blockwise_copy,
                                                          b_grid_buf,
                                                          b_block_buf,
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
                                                          num_k_block_main_loop);
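        // The pipeline (selected by PipelineVer / NumGemmKPrefetchStage) runs
        // the main K loop; conceptually, per iteration:
        //
        //   prefetch next A/B K-block from global;   // a/b_blockwise_copy
        //   block sync;
        //   blockwise_gemm.Run(LDS tiles -> c_thread_buf);
        //   block sync;
        //   commit the prefetched K-block into LDS and advance by the
        //   slice copy steps;
        //
        // (conceptual sketch only; the exact schedule lives in
        // gridwise_gemm_pipeline_selector.hpp and the pipeline versions)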
        // write out the per-thread C results: registers -> global memory
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 =
            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2();

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);

        constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
        constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
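        // The C tile is addressed through a 6-D (M0, N0, M1, N1, M2, N2) view:
        // the M0*M1*M2 and N0*N1*N2 products span MPerBlock and NPerBlock, and
        // the innermost (M2, N2) pair is each thread's private footprint, read
        // back above as (MPerThread, NPerThread). (Interpretation inferred from
        // the descriptor names; the exact wave/lane split is defined in
        // blockwise_gemm_dpp.hpp.)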
        const auto c_thread_mtx_on_block =
            blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);

        const index_t m_thread_data_on_grid =
            m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];

        const index_t n_thread_data_on_grid =
            n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
        // ... adaptors splitting the grid M/N coordinates into (M0, M1, M2) /
        // (N0, N1, N2) ...

        const auto m_thread_data_on_grid_idx =
            m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_grid));

        const auto n_thread_data_on_grid_idx =
            n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_grid));
        auto c_thread_copy =
            ThreadwiseTensorSliceTransfer_v1r3</* ... */
                                               decltype(c_thread_desc_m0_n0_m1_n1_m2_n2),
                                               decltype(c_grid_desc_m0_n0_m1_n1_m2_n2),
                                               CElementwiseOperation,
                                               /* ... */
                                               CThreadTransferSrcDstAccessOrder,
                                               CThreadTransferSrcDstVectorDim,
                                               CThreadTransferDstScalarPerVector,
                                               CGlobalMemoryDataOperation,
                                               /* ... */>{
                c_grid_desc_m0_n0_m1_n1_m2_n2,
                make_multi_index(m_thread_data_on_grid_idx[I0],
                                 n_thread_data_on_grid_idx[I0],
                                 m_thread_data_on_grid_idx[I1],
                                 n_thread_data_on_grid_idx[I1],
                                 m_thread_data_on_grid_idx[I2],
                                 n_thread_data_on_grid_idx[I2]),
                c_element_op};
        c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2,
                          make_tuple(I0, I0, I0, I0, I0, I0),
                          c_thread_buf,
                          c_grid_desc_m0_n0_m1_n1_m2_n2,
                          c_grid_buf);
    }
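    // Each thread thus writes its (MPerThread x NPerThread) accumulator
    // sub-tile from registers straight to global memory, applying c_element_op
    // on the way out; the copy above is a threadwise transfer, which suggests
    // no extra LDS round-trip is used for the C tile in this kernel.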