template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename CGridDesc_M_N,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
#if CK_USE_WAVES_PER_EU
    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
        kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
                                const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
                                const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
                                const CGridDesc_M_N c_grid_desc_m_n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx94__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_c_grid,
                                                  p_shared,
                                                  a_grid_desc_k0_m_k1,
                                                  b_grid_desc_k0_n_k1,
                                                  c_grid_desc_m_n);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
    ignore = a_grid_desc_k0_m_k1;
    ignore = b_grid_desc_k0_n_k1;
    ignore = c_grid_desc_m_n;
#endif
}
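// Note on the data layout: A is consumed as a 3D [K0, M, K1] tensor and B as
// [K0, N, K1], i.e. the reduction dimension is split as K = K0 * K1 so that
// each thread reads K1 contiguous elements per access. Worked example with
// illustrative numbers (not defaults of this header): K = 4096 with K1 = 8
// gives K0 = 512; a block using K0PerBlock = 4 then takes 512 / 4 = 128 steps
// through the K dimension.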
template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
#if CK_USE_WAVES_PER_EU
    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
        kernel_gemm_xdlops_v2r3(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx94__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    // build the (padded) grid descriptors from the kernel argument, forced
    // into scalar registers via amd_wave_read_first_lane
    const auto a_grid_desc_k0_m_k1 =
        amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(
            karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
    const auto b_grid_desc_k0_n_k1 =
        amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(
            karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
    const auto c_grid_desc_m_n =
        amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N(
            karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));

    GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
                                                  karg.p_b_grid,
                                                  karg.p_c_grid,
                                                  p_shared,
                                                  a_grid_desc_k0_m_k1,
                                                  b_grid_desc_k0_n_k1,
                                                  c_grid_desc_m_n);
#else
    ignore = karg;
#endif
}
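// Host-side sketch of driving the karg entry point above. Everything here is
// an illustrative assumption (buffer pointers, the `GG` alias, launch config),
// not an API defined by this header:
//
//   using GG = /* a full GridwiseGemm instantiation */;
//   auto karg = typename GG::Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
//   if(GG::CheckValidity(karg)) // Argument derives from Problem
//   {
//       const index_t grid_size = GG::CalculateGridSize(karg.M, karg.N);
//       if(GG::CalculateHasMainKBlockLoop(karg.K))
//           kernel_gemm_xdlops_v2r3<GG, true><<<grid_size, BlockSize>>>(karg);
//       else
//           kernel_gemm_xdlops_v2r3<GG, false><<<grid_size, BlockSize>>>(karg);
//   }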
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          index_t K1Value,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          index_t NumGemmKPrefetchStage = 1,
          LoopScheduler LoopSched       = make_default_loop_scheduler(),
          PipelineVersion PipelineVer   = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 is kept as a compile-time Number
    static constexpr auto K1 = Number<K1Value>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
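    // One plausible instantiation of the tuning parameters above, in the style
    // of CK's fp16 configs (illustrative values, not defaults of this header):
    //   BlockSize = 256, MPerBlock = 256, NPerBlock = 128, K0PerBlock = 4,
    //   K1 = 8, MPerXDL = NPerXDL = 32, MXdlPerWave = 4, NXdlPerWave = 2,
    //   ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>.
    // CheckValidity below encodes the constraints such a configuration must
    // satisfy, e.g. MPerBlock % (MPerXDL * MXdlPerWave) == 0.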
    // ... (CalculateGridSize(M, N) and the CalculateMPadded/NPadded/K0 helpers elided)

    template <typename CGridDesc_M_N>
    __host__ static auto CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return std::make_tuple(Block2CTileMap::CalculateGridSize(c_grid_desc_m_n), 1, 1);
    }
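    // Example: with MPerBlock = 256, NPerBlock = 128 (illustrative values), an
    // M = N = 4096 problem yields (4096 / 256) * (4096 / 128) = 16 * 32 = 512
    // C tiles, so the returned grid is {512, 1, 1} under a tile map that
    // linearizes the 2D (m, n) tile index.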
    struct Problem
    {
        __host__ Problem(index_t M_, index_t N_, index_t K_,
                         index_t StrideA_, index_t StrideB_, index_t StrideC_)
            : M{M_}, N{N_}, K{K_},
              StrideA{StrideA_}, StrideB{StrideB_}, StrideC{StrideC_},
              MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)}, K0{CalculateK0(K_)}
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {"
                      << "M:" << M << ", N:" << N << ", K:" << K
                      // ... strides and padded sizes ...
                      << "K0:" << K0 << "}" << std::endl;
        }

        index_t M, N, K;
        index_t StrideA, StrideB, StrideC;
        index_t MPadded, NPadded;
        index_t K0;
    };
    struct Argument : public Problem
    {
        __host__ Argument(const FloatAB* p_a_grid_,
                          const FloatAB* p_b_grid_,
                          FloatC* p_c_grid_,
                          index_t M_, index_t N_, index_t K_,
                          index_t StrideA_, index_t StrideB_, index_t StrideC_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_c_grid{p_c_grid_}
        {
        }

        const FloatAB* p_a_grid;
        const FloatAB* p_b_grid;
        FloatC* p_c_grid;
    };
    using GridwiseGemmPipe = remove_cvref_t<
        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
#if CK_GFX90A_DENORM_WORKAROUND
    // gfx90a denorm workaround: carry fp16 A/B data through bf16 storage
    using FloatABAdjusted = conditional_t<is_same<FloatAB, half_t>::value, bhalf_t, FloatAB>;
#else
    using FloatABAdjusted = FloatAB;
#endif

    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix tile in LDS, destination of the blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
                // pad the M stride by one K1 vector to reduce LDS bank conflicts
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            else
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
        }();

        return a_block_desc_k0_m_k1;
    }
    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix tile in LDS, destination of the blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
                // pad the N stride by one K1 vector to reduce LDS bank conflicts
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            else
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
        }();

        return b_block_desc_k0_n_k1;
    }
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            GetABlockDescriptor_K0PerBlock_MPerBlock_K1().GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
    }
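    // Worked example (illustrative values, no Extra LDS padding): K0PerBlock =
    // 4, MPerBlock = 256, NPerBlock = 128, K1 = 8, FloatAB = fp16. The A tile
    // holds 4 * 256 * 8 = 8192 elements and the B tile 4 * 128 * 8 = 4096, so
    // the block reserves (8192 + 4096) * 2 B = 24 KiB of the 64 KiB LDS
    // available per workgroup on gfx908/gfx90a.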
    template <typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N>
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                  const CGridDesc_M_N& c_grid_desc_m_n)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 needs to be known at compile-time");

        static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
                      "Invalid tuning param!");

        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        // A, B and C must agree on M, N, K0 and K1
        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
             K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
             K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
        {
            return false;
        }

        // this overload requires block-aligned problem sizes (no padding)
        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
        {
            return false;
        }

        // check gridwise gemm pipeline
        const auto num_k_loop = K0 / K0PerBlock;

        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            return false;
        }

        return true;
    }
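    // Example: this overload rejects M = 1000 with MPerBlock = 256 because
    // 1000 % 256 != 0; such shapes are handled by the Problem-based variant
    // further below, which pads M/N/K0 before building the descriptors.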
365 "wrong! K1 need to be known at compile-time");
367 static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
368 (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
369 "Invalid tuning param!");
373 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
386 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
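    // Example: K = 4096 with K0PerBlock = 4 and K1 = 8 (illustrative values)
    // gives num_loop = 4096 / (4 * 8) = 128; the pipeline then reports whether
    // it needs its steady-state main loop in addition to prologue/epilogue.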
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc& c_grid_desc_m_n)
    {
        // same LDS tile descriptors as the Get helpers above (honoring
        // ABlockLdsExtraM / BBlockLdsExtraN)
        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        using BlockwiseGemm = remove_cvref_t<
            decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
                     BlockSize,
                     FloatABAdjusted,
                     FloatAcc,
                     decltype(a_block_desc_k0_m_k1),
                     decltype(b_block_desc_k0_n_k1),
                     MPerXDL,
                     NPerXDL,
                     MXdlPerWave,
                     NXdlPerWave,
                     K1,
                     LoopSched>())>;

        return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
    }
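    // The 8D [M0, N0, M1, N1, M2, M3, M4, N2] view of C mirrors how xdlops
    // MFMA output is laid out: M0/N0 count xdlops tiles per wave (the
    // M/NXdlPerWave repeats), M1/N1 index the waves of the block, and
    // M2/M3/M4/N2 describe the per-lane register layout inside one MFMA
    // result.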
    template <bool HasMainKBlockLoop,
              typename AGridDesc_K0_M_K1,
              typename BGridDesc_K0_N_K1,
              typename CGridDesc_M_N>
    __device__ static void Run(const FloatAB* p_a_grid,
                               const FloatAB* p_b_grid,
                               FloatC* p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                               const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                               const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};
        // divide block work by [M, N]
        const auto block_2_ctile_map = Block2CTileMap{c_grid_desc_m_n};

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
                          c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1))))
        {
            return;
        }

        // readfirstlane pins the block offsets in SGPRs: the value is uniform
        // across the wave, so scalar registers and scalar loads can be used
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // LDS alignment and tile descriptors for A/B
        constexpr auto max_lds_align = K1;

        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
        // A matrix blockwise copy: global [K0PerBlock, MPerBlock, K1] slice -> LDS
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<K0PerBlock, MPerBlock, K1>,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                FloatAB,
                                                FloatABAdjusted,
                                                decltype(a_grid_desc_k0_m_k1),
                                                decltype(a_block_desc_k0_m_k1),
                                                ABlockTransferSrcAccessOrder,
                                                Sequence<1, 0, 2>,
                                                ABlockTransferSrcVectorDim,
                                                2,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                true,
                                                NumGemmKPrefetchStage>(
                a_grid_desc_k0_m_k1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_k0_m_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});
        // B matrix blockwise copy: global [K0PerBlock, NPerBlock, K1] slice -> LDS
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<K0PerBlock, NPerBlock, K1>,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                FloatAB,
                                                FloatABAdjusted,
                                                decltype(b_grid_desc_k0_n_k1),
                                                decltype(b_block_desc_k0_n_k1),
                                                BBlockTransferSrcAccessOrder,
                                                Sequence<1, 0, 2>,
                                                BBlockTransferSrcVectorDim,
                                                2,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                true,
                                                NumGemmKPrefetchStage>(
                b_grid_desc_k0_n_k1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_k0_n_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});
        // blockwise GEMM: xdlops accumulators fed from the A/B LDS tiles
        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
                BlockSize,
                FloatABAdjusted,
                FloatAcc,
                decltype(a_block_desc_k0_m_k1),
                decltype(b_block_desc_k0_n_k1),
                MPerXDL,
                NPerXDL,
                MXdlPerWave,
                NXdlPerWave,
                K1,
                LoopSched>();

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
        // LDS allocation for A and B: B starts after the aligned A region
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatABAdjusted*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_k0_n_k1.GetElementSpaceSize());
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        // gridwise GEMM pipeline
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
                                                          a_block_desc_k0_m_k1,
                                                          a_blockwise_copy,
                                                          a_grid_buf,
                                                          a_block_buf,
                                                          a_block_slice_copy_step,
                                                          b_grid_desc_k0_n_k1,
                                                          b_block_desc_k0_n_k1,
                                                          b_blockwise_copy,
                                                          b_grid_buf,
                                                          b_block_buf,
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
                                                          num_k_block_main_loop);
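        // The pipeline interleaves the two global->LDS copies with the
        // blockwise math; with NumGemmKPrefetchStage = 1 this is the usual
        // double-buffered scheme where iteration i computes on the resident
        // tiles while iteration i+1's tiles are in flight.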
        // output: register to global memory
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
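        // Sanity relations implied by the construction: M0 * M1 * M2 * M3 * M4
        // == MPerBlock and N0 * N1 * N2 == NPerBlock; M is a merge of
        // [M0, M1, M2, M3, M4] and N of [N0, N1, N2], which is exactly what
        // the two adaptors below invert to locate this thread's C slice.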
        // calculate origin of thread output tensor on global memory
        const auto c_thread_mtx_on_block =
            blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_grid =
            m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];

        const index_t n_thread_data_on_grid =
            n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
        const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_grid_idx =
            m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_grid));

        const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_grid_idx =
            n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_grid));
        auto c_thread_copy =
            ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                               FloatC,
                                               decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                               decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                               CElementwiseOperation,
                                               Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
                                               CThreadTransferSrcDstAccessOrder,
                                               CThreadTransferSrcDstVectorDim,
                                               CThreadTransferDstScalarPerVector,
                                               CGlobalMemoryDataOperation,
                                               1,
                                               true>{
                c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                make_multi_index(m_thread_data_on_grid_idx[I0],
                                 n_thread_data_on_grid_idx[I0],
                                 m_thread_data_on_grid_idx[I1],
                                 n_thread_data_on_grid_idx[I1],
                                 m_thread_data_on_grid_idx[I2],
                                 m_thread_data_on_grid_idx[I3],
                                 m_thread_data_on_grid_idx[I4],
                                 n_thread_data_on_grid_idx[I2]),
                c_element_op};
        c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                          c_thread_buf,
                          c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                          c_grid_buf);
    }
};
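// CGlobalMemoryDataOperation is typically InMemoryDataOperationEnum::Set for a
// plain GEMM; a split-K style caller would instead instantiate with AtomicAdd
// so that partial results over K from different workgroups combine in C.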
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          tensor_operation::device::GemmSpecialization GemmSpec,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          index_t K1Value,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          index_t NumGemmKPrefetchStage = 1,
          LoopScheduler LoopSched       = make_default_loop_scheduler(),
          PipelineVersion PipelineVer   = PipelineVersion::v1>
// ... (struct head elided)
// The argument lists below forward the tuning parameters above:
// ...
        CGlobalMemoryDataOperation,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        K1Value,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsExtraN,
        CThreadTransferSrcDstAccessOrder,
        CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector,
        NumGemmKPrefetchStage,
// ...
        CGlobalMemoryDataOperation,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        K1Value,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsExtraN,
        CThreadTransferSrcDstAccessOrder,
        CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector,
        NumGemmKPrefetchStage,
// ...
    __device__ static auto
    MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t K0, index_t StrideA)
    {
        const auto a_grid_desc_m_k = [&]() {
            // ... naive (M, K) descriptor built from StrideA
        }();

        // ... pad (M, K) to (MPad, KPad) according to GemmSpec, with
        const auto KPad = K0Pad * K1Value;
        // ... then unmerge K into (K0, K1)
    }

    __device__ static auto
    MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t K0, index_t StrideB)
    {
        const auto b_grid_desc_k_n = [&]() {
            // ... naive (K, N) descriptor built from StrideB
        }();

        // ... pad (K, N) to (KPad, NPad) according to GemmSpec, with
        const auto KPad = K0Pad * K1Value;
        // ... then unmerge K into (K0, K1)
    }

    __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_m_n = [&]() {
            // ... naive (M, N) descriptor built from StrideC
        }();

        // ... optionally right-pad M -> MPad and N -> NPad according to GemmSpec
    }
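    // The padded extents are the block-aligned round-ups produced by the
    // CalculateMPadded / CalculateNPadded helpers, presumably (sketch,
    // assuming the usual CK helpers):
    //   MPad = math::integer_divide_ceil(M, MPerBlock) * MPerBlock
    //   NPad = math::integer_divide_ceil(N, NPerBlock) * NPerBlock
    // so a make_right_pad_transform can extend M -> MPad and N -> NPad without
    // touching the stored data.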
971 "wrong! K1 need to be known at compile-time");
973 static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
974 (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
975 "Invalid tuning param!");
982 if(!(problem.M % MPerBlock == 0))
993 if(!(problem.N % NPerBlock == 0))
1004 if(!(problem.K0 % K0PerBlock == 0))
1012 if(problem.K % ABlockTransferSrcScalarPerVector != 0)
1019 if(problem.M % ABlockTransferSrcScalarPerVector != 0)
1027 if(problem.N % BBlockTransferSrcScalarPerVector != 0)
1034 if(problem.K % BBlockTransferSrcScalarPerVector != 0)
1043 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
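// Example: an M = 1000, N = 2000, K = 4096 fp16 problem with MPerBlock = 256,
// NPerBlock = 128 (illustrative values) fails the raw divisibility checks, but
// with a padding GemmSpecialization the descriptors are built on MPadded =
// 1024 and NPadded = 2048 instead; the vector-width checks still require the
// contiguous extent of each operand to be divisible by its SrcScalarPerVector.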