template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename CGridDesc_M_N,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
#if CK_USE_WAVES_PER_EU
    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
        kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
                                const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
                                const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
                                const CGridDesc_M_N c_grid_desc_m_n)
{
#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
    // LDS workspace shared by the A and B blockwise copies
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_c_grid,
                                                  p_shared,
                                                  a_grid_desc_k0_m_k1,
                                                  b_grid_desc_k0_n_k1,
                                                  c_grid_desc_m_n);
#else
    // compiled-out path on unsupported targets: keep the parameters "used"
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
    ignore = a_grid_desc_k0_m_k1;
    ignore = b_grid_desc_k0_n_k1;
    ignore = c_grid_desc_m_n;
#endif
}
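// --- Example (illustrative sketch, not part of the original header) --------
// The kernel above combines three HIP idioms: __launch_bounds__ to bound the
// compiler's occupancy assumptions, amdgpu_waves_per_eu to hint wave
// residency, and a compile-time architecture guard so the body is only
// emitted for gfx908/gfx90a/gfx94x. A minimal standalone kernel using the
// same pattern (all names and numbers below are hypothetical):
#include <hip/hip_runtime.h>

template <typename T>
__global__ void __launch_bounds__(256, 1)
    __attribute__((amdgpu_waves_per_eu(1, 4)))
    example_arch_guarded_kernel(T* out, T value)
{
#if defined(__gfx908__) || defined(__gfx90a__)
    out[blockIdx.x * blockDim.x + threadIdx.x] = value; // real work
#else
    (void)out; // compiled-out path: silence unused-parameter warnings
    (void)value;
#endif
}
// ---------------------------------------------------------------------------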
template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
#if CK_USE_WAVES_PER_EU
    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
        kernel_gemm_xdlops_v2r3(typename GridwiseGemm::Argument karg) // packed-argument overload
{
#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    // rebuild the tensor descriptors on the device from the plain scalars in
    // karg; only pointers and integers cross the kernel-argument boundary
    const auto a_grid_desc_k0_m_k1 =
        amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(
            karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
    const auto b_grid_desc_k0_n_k1 =
        amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(
            karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
    const auto c_grid_desc_m_n =
        amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N(
            karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));

    GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
                                                  karg.p_b_grid,
                                                  karg.p_c_grid,
                                                  p_shared,
                                                  a_grid_desc_k0_m_k1,
                                                  b_grid_desc_k0_n_k1,
                                                  c_grid_desc_m_n);
#else
    ignore = karg;
#endif
}
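// --- Example (illustrative sketch, not part of the original header) --------
// The amd_wave_read_first_lane wrapper above broadcasts lane 0's result to
// the whole wavefront, letting the compiler keep the rebuilt descriptors in
// scalar registers. A hypothetical minimal analogue of that broadcast:
__device__ inline unsigned example_make_wave_uniform(unsigned v)
{
    return __builtin_amdgcn_readfirstlane(v); // broadcast lane 0 to all lanes
}
// ---------------------------------------------------------------------------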
template <// ... (block size, data types, tile sizes and XDL shape parameters
          //      elided in this listing)
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          index_t NumGemmKPrefetchStage = 1,
          LoopScheduler LoopSched       = make_default_loop_scheduler(),
          PipelineVersion PipelineVer   = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
    static constexpr auto I0 = Number<0>{};
    // ... (I1 .. I7 and K1 = Number<K1Value>{} defined likewise)

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
    template <typename CGridDesc_M_N>
    __host__ static auto CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return std::make_tuple(Block2CTileMap::CalculateGridSize(c_grid_desc_m_n), 1, 1);
    }
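    // --- Example (illustrative sketch, not part of the original header) ----
    // A typical block-to-C-tile map launches one workgroup per MPerBlock x
    // NPerBlock tile of C, so the grid size reduces to a tile count (compare
    // math::integer_divide_ceil). Standalone arithmetic, hypothetical names:
    static long example_grid_size(long m, long n, long m_per_block, long n_per_block)
    {
        const long m_tiles = (m + m_per_block - 1) / m_per_block; // ceil divide
        const long n_tiles = (n + n_per_block - 1) / n_per_block;
        return m_tiles * n_tiles; // one workgroup per C tile
    }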
    // ... (CalculateMPadded / CalculateNPadded / CalculateK0 helpers elided)

    struct Problem
    {
        __host__ Problem(index_t M_,
                         index_t N_,
                         index_t K_,
                         index_t StrideA_,
                         index_t StrideB_,
                         index_t StrideC_)
            : M{M_}, N{N_}, K{K_},
              StrideA{StrideA_}, StrideB{StrideB_}, StrideC{StrideC_},
              MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)},
              K0{CalculateK0(K_)}
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {"
                      << "M:" << M << ", "
                      << "N:" << N << ", "
                      << "K:" << K << ", "
                      << "SA:" << StrideA << ", "
                      << "SB:" << StrideB << ", "
                      << "SC:" << StrideC << ", "
                      << "MP:" << MPadded << ", "
                      << "NP:" << NPadded << ", "
                      << "K0:" << K0
                      << "}" << std::endl;
        }

        index_t M;
        index_t N;
        index_t K;
        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
        index_t MPadded;
        index_t NPadded;
        index_t K0;
    };
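    // --- Example (illustrative sketch, not part of the original header) ----
    // MPadded and NPadded round the problem sizes up to whole block tiles, in
    // the spirit of math::integer_least_multiple. Standalone arithmetic:
    static long example_least_multiple(long x, long x_per_block)
    {
        return ((x + x_per_block - 1) / x_per_block) * x_per_block;
    }
    // e.g. M = 1000, MPerBlock = 256  ->  MPadded = 1024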
    struct Argument : public tensor_operation::device::BaseArgument, public Problem
    {
        __host__ Argument(const FloatAB* p_a_grid_,
                          const FloatAB* p_b_grid_,
                          FloatC* p_c_grid_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
                          index_t StrideC_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_c_grid{p_c_grid_}
        {
        }

        const FloatAB* p_a_grid;
        const FloatAB* p_b_grid;
        FloatC* p_c_grid;
    };
    using GridwiseGemmPipe = remove_cvref_t<
        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
    // workaround for an fp16 MFMA denorm issue on gfx90a: route fp16 A/B data
    // through bf16 for the MFMA instructions
#if CK_GFX90A_DENORM_WORKAROUND
    using FloatABAdjusted = conditional_t<is_same<FloatAB, half_t>::value, bhalf_t, FloatAB>;
#else
    using FloatABAdjusted = FloatAB;
#endif
    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS, destination of the blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                // pad the K0 stride by one M row to reduce LDS bank conflicts
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        return a_block_desc_k0_m_k1;
    }
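    // --- Example (illustrative sketch, not part of the original header) ----
    // With ABlockLdsExtraM, each K0 slice of the LDS tile has a stride of
    // (MPerBlock + 1) * K1 elements instead of MPerBlock * K1; the extra row
    // staggers accesses across LDS banks. Standalone stride arithmetic:
    static long example_lds_k0_slice_stride(long m_per_block, long k1, bool extra_m)
    {
        return (extra_m ? m_per_block + 1 : m_per_block) * k1;
    }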
    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS, destination of the blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                // pad the K0 stride by one N row to reduce LDS bank conflicts
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        return b_block_desc_k0_n_k1;
    }
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;

        // aligned LDS footprints of the A and B tiles
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            GetABlockDescriptor_K0PerBlock_MPerBlock_K1().GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
    }
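    // --- Example (illustrative sketch, not part of the original header) ----
    // The LDS budget is the K1-aligned A tile plus the K1-aligned B tile,
    // counted in elements and scaled by the element size. Standalone
    // arithmetic for the unpadded descriptors:
    static long example_lds_bytes(long k0_per_block, long m_per_block,
                                  long n_per_block, long k1, long elem_bytes)
    {
        auto least_multiple = [](long x, long a) { return ((x + a - 1) / a) * a; };
        const long a_elems  = least_multiple(k0_per_block * m_per_block * k1, k1);
        const long b_elems  = least_multiple(k0_per_block * n_per_block * k1, k1);
        return (a_elems + b_elems) * elem_bytes;
    }
    // e.g. K0PerBlock=4, MPerBlock=NPerBlock=256, K1=8, fp16 -> 32 KiB of LDS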
    template <typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N>
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                  const CGridDesc_M_N& c_grid_desc_m_n)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 needs to be known at compile-time");

        static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXDL) == 0),
                      "Invalid tuning param!");

        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        // the A, B and C descriptors must agree on M, N, K0 and K1
        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
        {
            return false;
        }

        // the problem must tile exactly
        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
        {
            return false;
        }

        // check the gridwise GEMM pipeline
        const auto num_k_loop = K0 / K0PerBlock;

        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            return false;
        }

        return true;
    }
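    // --- Example (illustrative sketch, not part of the original header) ----
    // The runtime part of CheckValidity restated as plain host-side checks
    // one might run before committing to this (non-padding) kernel:
    static bool example_shape_is_supported(long m, long n, long k0, long m_per_block,
                                           long n_per_block, long k0_per_block)
    {
        return m % m_per_block == 0 && n % n_per_block == 0 && k0 % k0_per_block == 0;
    }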
357 "wrong! K1 need to be known at compile-time");
359 static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
360 (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
361 "Invalid tuning param!");
365 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
    static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / (K0PerBlock * K1);

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }
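    // --- Example (illustrative sketch, not part of the original header) ----
    // Each main-loop iteration consumes K0PerBlock * K1 elements along K, so
    // the trip count and the has-main-loop predicate reduce to (what the
    // one-stage pipeline typically reports):
    static bool example_has_main_k_block_loop(long k, long k0_per_block, long k1)
    {
        const long num_loop = k / (k0_per_block * k1);
        return num_loop > 1;
    }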
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc& c_grid_desc_m_n)
    {
        constexpr auto max_lds_align = K1;

        // same LDS block descriptors as in Get{A,B}BlockDescriptor_* above
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                // ... (padded descriptor)
            }
            else
            {
                // ... (max_lds_align-aligned descriptor)
            }
        }();

        constexpr auto b_block_desc_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                // ... (padded descriptor)
            }
            else
            {
                // ... (max_lds_align-aligned descriptor)
            }
        }();

        using BlockwiseGemm =
            remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
                BlockSize,
                FloatABAdjusted,
                FloatAcc,
                decltype(a_block_desc_k0_m_k1),
                decltype(b_block_desc_k0_n_k1),
                // ... (MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, K1)
                >())>;

        return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
    }
    template <bool HasMainKBlockLoop,
              typename AGridDesc_K0_M_K1,
              typename BGridDesc_K0_N_K1,
              typename CGridDesc_M_N>
    __device__ static void Run(const FloatAB* p_a_grid,
                               const FloatAB* p_b_grid,
                               FloatC* p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                               const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                               const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);

        // global-memory buffers for A, B and C
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};
        // map this workgroup onto its C tile
        const auto block_2_ctile_map = /* ... block-to-C-tile map built from c_grid_desc_m_n ... */;

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
                          c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1))))
        {
            return;
        }

        // readfirstlane keeps the wave-uniform block offsets in scalar registers
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // LDS alignment and block descriptors for A and B
        constexpr auto max_lds_align = K1;

        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
        // blockwise copy: A, global memory -> LDS
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<K0PerBlock, MPerBlock, K1>,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                FloatAB,
                                                FloatABAdjusted,
                                                decltype(a_grid_desc_k0_m_k1),
                                                decltype(a_block_desc_k0_m_k1),
                                                ABlockTransferSrcAccessOrder,
                                                // ...
                                                ABlockTransferSrcVectorDim,
                                                // ...
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                // ...
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                // ...
                                                NumGemmKPrefetchStage>(
                a_grid_desc_k0_m_k1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_k0_m_k1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

        // blockwise copy: B, global memory -> LDS
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<K0PerBlock, NPerBlock, K1>,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                FloatAB,
                                                FloatABAdjusted,
                                                decltype(b_grid_desc_k0_n_k1),
                                                decltype(b_block_desc_k0_n_k1),
                                                BBlockTransferSrcAccessOrder,
                                                // ...
                                                BBlockTransferSrcVectorDim,
                                                // ...
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                // ...
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                // ...
                                                NumGemmKPrefetchStage>(
                b_grid_desc_k0_n_k1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_k0_n_k1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});
        // blockwise GEMM: each wave accumulates its XDL tiles of C in registers
        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
                // ... (block size, data types, XDL shape)
                decltype(a_block_desc_k0_m_k1),
                decltype(b_block_desc_k0_n_k1),
                // ...
                >();

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // carve the LDS workspace into the A tile followed by the B tile
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatABAdjusted*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_k0_n_k1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

        // main K loop, driven by the gridwise GEMM pipeline
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
                                                          a_block_desc_k0_m_k1,
                                                          a_blockwise_copy,
                                                          a_grid_buf,
                                                          a_block_buf,
                                                          a_block_slice_copy_step,
                                                          b_grid_desc_k0_n_k1,
                                                          b_block_desc_k0_n_k1,
                                                          b_blockwise_copy,
                                                          b_grid_buf,
                                                          b_block_buf,
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
                                                          num_k_block_main_loop);
        // descriptors of the accumulator tile, per thread and per block
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
        // locate this thread's C sub-tile on the grid
        const auto c_thread_mtx_on_block =
            blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_grid =
            m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];

        const index_t n_thread_data_on_grid =
            n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_grid_idx =
            m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_grid));

        const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_grid_idx =
            n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_grid));

        // copy the accumulator registers to global memory
        auto c_thread_copy =
            ThreadwiseTensorSliceTransfer_v1r3<// ... (src/dst data types)
                                               decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                               decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                               CElementwiseOperation,
                                               // ... (slice lengths)
                                               CThreadTransferSrcDstAccessOrder,
                                               CThreadTransferSrcDstVectorDim,
                                               CThreadTransferDstScalarPerVector,
                                               CGlobalMemoryDataOperation,
                                               // ...
                                               >{
                c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                make_multi_index(m_thread_data_on_grid_idx[I0],
                                 n_thread_data_on_grid_idx[I0],
                                 m_thread_data_on_grid_idx[I1],
                                 n_thread_data_on_grid_idx[I1],
                                 m_thread_data_on_grid_idx[I2],
                                 m_thread_data_on_grid_idx[I3],
                                 m_thread_data_on_grid_idx[I4],
                                 n_thread_data_on_grid_idx[I2]),
                c_element_op};

        c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                          c_thread_buf,
                          c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                          c_grid_buf);
    }
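    // --- Example (illustrative sketch, not part of the original header) ----
    // CalculateBottomIndex above inverts a merge transform: it splits a
    // linear index into mixed-radix coordinates by div/mod. Standalone
    // analogue for a three-way split with radices (r0, r1, r2), r0 slowest:
    static void example_split_linear_index(long idx, long r1, long r2, long out[3])
    {
        out[2] = idx % r2;        // fastest-varying coordinate
        out[1] = (idx / r2) % r1;
        out[0] = idx / (r2 * r1); // slowest-varying coordinate
    }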
    // ... (remaining members of the first gridwise GEMM elided)
};

// A second, padding-aware gridwise GEMM follows; it provides the grid
// descriptor makers and validity checks used by the packed-argument kernel
// overload above. Its template parameter list mirrors the first one:
template <// ... (data types, layouts, block/tile sizes, as in the first list)
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          index_t NumGemmKPrefetchStage = 1,
          // ... (loop scheduler and pipeline version)>
struct GridwiseGemm_xdlops_v2r3 // exact name elided in the damaged listing
    // The original forwards the parameters above, verbatim, in two
    // instantiations of the underlying gridwise GEMM machinery; one
    // forwarding list is reproduced here and the duplicate is elided:
    : GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
          // ...
          CGlobalMemoryDataOperation,
          AElementwiseOperation,
          BElementwiseOperation,
          CElementwiseOperation,
          // ...
          ABlockTransferThreadClusterLengths_K0_M_K1,
          ABlockTransferThreadClusterArrangeOrder,
          ABlockTransferSrcAccessOrder,
          ABlockTransferSrcVectorDim,
          ABlockTransferSrcScalarPerVector,
          ABlockTransferDstScalarPerVector_K1,
          AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          BBlockTransferThreadClusterLengths_K0_N_K1,
          BBlockTransferThreadClusterArrangeOrder,
          BBlockTransferSrcAccessOrder,
          BBlockTransferSrcVectorDim,
          BBlockTransferSrcScalarPerVector,
          BBlockTransferDstScalarPerVector_K1,
          BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          CThreadTransferSrcDstAccessOrder,
          CThreadTransferSrcDstVectorDim,
          CThreadTransferDstScalarPerVector,
          NumGemmKPrefetchStage,
          // ...>
{
    // ... (constants and size helpers elided)
    __device__ static auto
    MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t K0, index_t StrideA)
    {
        const auto a_grid_desc_m_k = [&]() {
            // ... (naive (M, K) descriptor built from StrideA, layout-dependent)
        }();

        // ... (K0Pad computed from K0; definition elided in this listing)
        const auto KPad = K0Pad * K1Value;
        // ... (right-pad M -> MPad and K -> KPad, then unmerge K into (K0, K1))
    }
    __device__ static auto
    MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t K0, index_t StrideB)
    {
        const auto b_grid_desc_k_n = [&]() {
            // ... (naive (K, N) descriptor built from StrideB, layout-dependent)
        }();

        // ... (K0Pad computed from K0; definition elided in this listing)
        const auto KPad = K0Pad * K1Value;
        // ... (right-pad N -> NPad and K -> KPad, then unmerge K into (K0, K1))
    }
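    // --- Example (illustrative sketch, not part of the original header) ----
    // Both makers compute KPad = K0Pad * K1Value, so K is rounded up until
    // the (K0, K1) split divides evenly. Assuming K0Pad rounds K0 up to a
    // multiple of K0PerBlock (hypothetical reconstruction):
    static long example_k_pad(long k, long k0_per_block, long k1)
    {
        const long k0     = (k + k1 - 1) / k1; // K split into K0 * K1 pieces
        const long k0_pad = ((k0 + k0_per_block - 1) / k0_per_block) * k0_per_block;
        return k0_pad * k1;                    // padded K
    }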
    __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_m_n = [&]() {
            // ... (naive (M, N) descriptor built from StrideC, layout-dependent)
        }();

        // ... (right-pad M -> MPad and N -> NPad where the GEMM specialization
        //      requires it)
    }
963 "wrong! K1 need to be known at compile-time");
965 static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
966 (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
967 "Invalid tuning param!");
974 if(!(problem.M % MPerBlock == 0))
985 if(!(problem.N % NPerBlock == 0))
996 if(!(problem.K0 % K0PerBlock == 0))
1004 if(problem.K % ABlockTransferSrcScalarPerVector != 0)
1011 if(problem.M % ABlockTransferSrcScalarPerVector != 0)
1019 if(problem.N % BBlockTransferSrcScalarPerVector != 0)
1026 if(problem.K % BBlockTransferSrcScalarPerVector != 0)
1035 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
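// --- Example (illustrative sketch, not part of the original header) --------
// Host-side gatekeeping for this variant, mirrored with plain integers and
// hypothetical tuning values (MPerBlock=256, NPerBlock=128, K0PerBlock=4,
// K1=8): validate the shape, then derive the launch grid.
static bool example_plan_launch(long m, long n, long k, long* grid_size_out)
{
    const long m_per_block = 256, n_per_block = 128, k0_per_block = 4, k1 = 8;
    const long k0 = (k + k1 - 1) / k1;
    if(m % m_per_block != 0 || n % n_per_block != 0 || k0 % k0_per_block != 0)
        return false; // a non-padding configuration cannot run this shape
    *grid_size_out = (m / m_per_block) * (n / n_per_block);
    return true;
}
// ---------------------------------------------------------------------------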