template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          // ...
          >
#if CK_USE_LAUNCH_BOUNDS
// ...
#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
            // ...
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          // ...
          >
#if CK_USE_LAUNCH_BOUNDS
// ...
#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

        GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
            // ...
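// --- standalone example (compile separately; illustrative only) -----------------
// The 2-LDS kernel above allocates two identical LDS tiles (p_shared_0 /
// p_shared_1) so the pipeline can prefetch the next K tile into one buffer while
// the MFMA stage reads the other. A minimal sketch of that ping-pong indexing in
// plain C++; the tile/stage names are stand-ins, not CK's pipeline.
#include <array>
#include <cstdio>

int main()
{
    std::array<int, 2> lds_tile{}; // stand-ins for p_shared_0 / p_shared_1

    const int num_k_tiles = 6;
    lds_tile[0]           = 0; // prologue: prefetch tile 0 into the "ping" buffer

    for(int i = 0; i < num_k_tiles; ++i)
    {
        if(i + 1 < num_k_tiles)
            lds_tile[(i + 1) % 2] = i + 1; // prefetch the next tile into "pong"

        // on the GPU a barrier (block_sync_lds) separates prefetch and compute
        std::printf("compute on tile %d from buffer %d\n", lds_tile[i % 2], i % 2);
    }
}
// --- end example -----------------------------------------------------------------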
template <typename ALayout,
          // ...
          typename AccDataType,
          typename CShuffleDataType,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          // ...
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA,
          bool PermuteA = false,
          bool PermuteB = false,
          bool DoElementwiseBeforeCShuffle = false>
        // ...
        KPerBlock < 128 && MPerXdl == 16))
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
// ...
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
// ...
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * KPerBlock;
// ...
auto K_t = K_Batch * KReadVec;
return (K + K_t - 1) / K_t * KReadVec;
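// --- standalone example (compile separately; sample values are assumptions) -----
// All four helpers above share one round-up pattern: pad K so it divides evenly
// across K_Batch chunks of KPerBlock (or KReadVec), then convert to the requested
// unit. A compile-time check of the arithmetic, mirroring CalculateAK0Padded:
// ceil(K / (K_Batch * KPerBlock)) * (KPerBlock / AK1).
constexpr int ak0_padded(int K, int K_Batch, int KPerBlock, int AK1Value)
{
    const int K_t = K_Batch * KPerBlock;
    return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
}

int main()
{
    // K = 1000, KBatch = 2, KPerBlock = 64: ceil(1000 / 128) = 8 tiles per batch,
    // and with AK1 = 8 each KPerBlock tile holds 64 / 8 = 8 AK0 slices.
    static_assert(ak0_padded(1000, 2, 64, 8) == 8 * 8);
    // An evenly divisible K needs no padding: 1024 / (2 * 64) = 8 exactly.
    static_assert(ak0_padded(1024, 2, 64, 8) == 8 * 8);
    return 0;
}
// --- end example -----------------------------------------------------------------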
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
const auto a_grid_desc_mraw_kraw = [&]() {
    if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
    // ...
    else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
    // ...
};

if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
             GemmSpec == GemmSpecialization::MNKPadding)
{
    const auto a_grid_desc_m_k =
    // ...
    return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                  GemmSpec == GemmSpecialization::MNPadding)
{
    // ...
        a_grid_desc_mraw_kraw,
    // ...
    return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                  GemmSpec == GemmSpecialization::NKPadding)
{
    // ...
        a_grid_desc_mraw_kraw,
    // ...
    return a_grid_desc_ak0_m_ak1;
}
else
{
    // ...
        a_grid_desc_mraw_kraw,
    // ...
    return a_grid_desc_ak0_m_ak1;
}
const auto b_grid_desc_nraw_kraw = [&]() {
// ...
};

static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
                GemmSpec != GemmSpecialization::Default),
              "pk_i4_t does not support padding");

if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
             GemmSpec == GemmSpecialization::MNKPadding)
{
    const auto b_grid_desc_n_k =
    // ...
    return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                  GemmSpec == GemmSpecialization::MNPadding)
{
    // ...
        b_grid_desc_nraw_kraw,
    // ...
    return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                  GemmSpec == GemmSpecialization::MKPadding)
{
    // ...
        b_grid_desc_nraw_kraw,
    // ...
    return b_grid_desc_bk0_n_bk1;
}
else
{
    if constexpr(!PermuteB)
    {
        // ...
            b_grid_desc_nraw_kraw,
        // ...
        return b_grid_desc_bk0_n_bk1;
    }
    else
    {
        // Weights are pre-shuffled: split BK0 into BK00 x BK01.
        constexpr index_t BK01 = KPerBlock / BK1Value;
        const index_t BK0_     = StrideB / BK1Value;
        const index_t BK00     = BK0_ / BK01;

        const auto b_grid_desc_bk00_n_bk01_bk1_permute =
        // ...
            b_grid_desc_bk00_n_bk01_bk1_permute,
        // ...
        return b_grid_desc_bk0_n_bk1_permute;
    }
}
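// --- standalone example (compile separately; values chosen for illustration) ----
// In the PermuteB branch the pre-shuffled weight tensor factors BK0 into
// BK00 x BK01, where BK01 is fixed by the block tile. A quick numeric check:
#include <cassert>

int main()
{
    constexpr int KPerBlock = 64, BK1Value = 16, StrideB = 512; // assumed sizes

    constexpr int BK01 = KPerBlock / BK1Value; // BK0 slices per block tile: 4
    const int BK0_     = StrideB / BK1Value;   // total BK0 slices along K:  32
    const int BK00     = BK0_ / BK01;          // block tiles along K:        8

    assert(BK00 * BK01 * BK1Value == StrideB); // the factorization covers all of K
    return 0;
}
// --- end example -----------------------------------------------------------------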
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
    constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

    return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
}

template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
    constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

    return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
}
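// --- standalone example (compile separately; numbers are illustrative) ----------
// Both tile-descriptor makers derive the wave grid from the block tile and the
// XDL instruction shape. For a 256x128 block tile with 32x32 XDLOPS and 4/2 XDL
// iterations per wave, the block maps onto a 2x2 grid of waves:
int main()
{
    constexpr int MPerBlock = 256, NPerBlock = 128; // assumed tuning parameters
    constexpr int MPerXdl = 32, NPerXdl = 32;
    constexpr int MXdlPerWave = 4, NXdlPerWave = 2;

    constexpr int MWaves = MPerBlock / (MXdlPerWave * MPerXdl); // 256 / 128 = 2
    constexpr int NWaves = NPerBlock / (NXdlPerWave * NPerXdl); // 128 /  64 = 2

    static_assert(MWaves * NWaves == 4, "4 waves = 256 threads on a wave64 GPU");
    return 0;
}
// --- end example -----------------------------------------------------------------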
__host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
    const auto c_grid_desc_mraw_nraw = [&]() {
    // ...
    }();

    if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                 GemmSpec == GemmSpecialization::MNKPadding)
    {
        // ...
    }
    else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                      GemmSpec == GemmSpecialization::MKPadding)
    {
        // ...
            c_grid_desc_mraw_nraw,
        // ...
    }
    else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                      GemmSpec == GemmSpecialization::NKPadding)
    {
        // ...
            c_grid_desc_mraw_nraw,
        // ...
    }
    else
    {
        return c_grid_desc_mraw_nraw;
    }
}
            AElementwiseOperation a_element_op,
            BElementwiseOperation b_element_op,
            CElementwiseOperation c_element_op)
std::cout << "problem {"
          // ...
          << "KRead:" << KRead << ", "
          << "AK0:" << AK0 << ", "
          << "BK0:" << BK0 << ", "
          << "MBlock: " << MBlock << ", "
          << "NBlock: " << NBlock << "}" << std::endl;
             const BDataType* p_b_grid_,
             CDataType* p_c_grid_,
             // ...
             bool is_reduce_                    = false,
             AElementwiseOperation a_element_op = AElementwiseOperation{},
             BElementwiseOperation b_element_op = BElementwiseOperation{},
             CElementwiseOperation c_element_op = CElementwiseOperation{})
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
// ...
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
// ...

if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
// ...
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
// ...

if constexpr(!PermuteB)
// ...
else
{
    const int k0_offset = karg.KRead * karg.N;
    // ...
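// --- standalone example (compile separately; a simplification, not CK's code) ---
// SplitKBatchOffset turns the workgroup's K-batch id into element offsets into A
// and B; the stride used depends on whether K is the contiguous axis of each
// operand. Sketch of the case where K is contiguous for both (the k_id parameter
// stands in for the batch index the kernel derives from blockIdx.z):
#include <cstdint>

struct SplitKOffsets
{
    int64_t a_k_split_offset;
    int64_t b_k_split_offset;
};

SplitKOffsets make_splitk_offsets(int64_t KRead, int k_id)
{
    // step KRead elements along the contiguous K axis of each operand
    return SplitKOffsets{k_id * KRead, k_id * KRead};
}

int main()
{
    // batch 3 of a split-K run with KRead = 256 starts 768 elements in
    const auto off = make_splitk_offsets(/*KRead=*/256, /*k_id=*/3);
    return (off.a_k_split_offset == 768 && off.b_k_split_offset == 768) ? 0 : 1;
}
// --- end example -----------------------------------------------------------------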
constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave    = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
// ...
constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
// ...
    a_lds_block_desc_permuted,
// ...
    a_lds_block_desc_ak0_mldslayer_m_ak1,
// ...
return a_lds_block_desc_ak0_m_ak1;
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto M1 = MPerBlock / M0;

constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
constexpr auto KThreadRead      = WaveSize / MPerXdl;
constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                           ? 1
                           : 128 / (AK1Number * M0 * sizeof(ADataType));
constexpr auto KThreadReadPerm =
    (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
        ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
        : KThreadRead;

constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
                           ? 1
                           : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
                                  ? M0
                                  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
// ...
    Number<kfold * M0 / mpair>{},
// ...
    a_lds_block_desc_permuted,
// ...
    a_lds_block_desc_unmerged,
// ...
    Number<KThreadWrite / kfold / KThreadReadPerm>{},
// ...
return a_lds_block_desc_ak0_m_ak1;
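// --- standalone example (compile separately; fp16/fp32 sizes assumed) -----------
// The kfold/mpair constants above size the XOR permutation so one LDS write burst
// spans 128 bytes: a K1 x M0 slab narrower than that gets several K0 slices
// folded together. Evaluating that rule at compile time:
#include <cstdio>

constexpr int fold_to_128_bytes(int k1, int m0, int elem_bytes)
{
    // no folding once a K1 x M0 slab already exceeds 128 bytes,
    // otherwise fold enough K0 slices to fill 128 bytes
    return (k1 * m0 * elem_bytes > 128) ? 1 : 128 / (k1 * m0 * elem_bytes);
}

int main()
{
    // fp16 (2 B), AK1 = 8, M0 = 4 -> 64-byte slab -> fold 2 K0 slices
    static_assert(fold_to_128_bytes(8, 4, 2) == 2);
    // fp32 (4 B), AK1 = 4, M0 = 16 -> 256-byte slab -> no folding
    static_assert(fold_to_128_bytes(4, 16, 4) == 1);
    std::puts("kfold rule holds");
}
// --- end example -----------------------------------------------------------------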
constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave    = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
// ...
constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
// ...
    b_lds_block_desc_permuted,
// ...
    b_lds_block_desc_bk0_nldslayer_n_bk1,
// ...
return b_lds_block_desc_bk0_n_bk1;
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr auto N1 = NPerBlock / N0;

constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
constexpr auto KThreadRead      = WaveSize / NPerXdl;
constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
                           ? 1
                           : 128 / (BK1Number * N0 * sizeof(BDataType));
constexpr auto KThreadReadPerm =
    (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
        ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
        : KThreadRead;

constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
                           ? 1
                           : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
                                  ? N0
                                  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
// ...
    Number<kfold * N0 / npair>{},
// ...
    b_lds_block_desc_permuted,
// ...
    b_lds_block_desc_unmerged,
// ...
    Number<KThreadWrite / kfold / KThreadReadPerm>{},
// ...
return b_lds_block_desc_bk0_n_bk1;
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
// ...

return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    ABlockTransferSrcScalarPerVector,
    BBlockTransferSrcScalarPerVector,
// ...
    a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
// ...
    b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// ...
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
// ...
constexpr auto c_block_size =
    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
                  b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
                 c_block_size * sizeof(CShuffleDataType));
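// --- standalone example (compile separately; tile sizes and types assumed) ------
// The shared-memory budget is the larger of the two phases that reuse the same
// allocation: the A+B staging tiles of the main loop versus the C-shuffle tile of
// the epilogue. A quick model of that max():
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main()
{
    // illustrative tiles: 256x64 fp16 A, 128x64 fp16 B, 256x32 fp32 C-shuffle
    const std::size_t a_bytes = 256 * 64 * 2; // fp16
    const std::size_t b_bytes = 128 * 64 * 2; // fp16
    const std::size_t c_bytes = 256 * 32 * 4; // fp32

    const std::size_t lds_bytes = std::max(a_bytes + b_bytes, c_bytes);
    std::printf("LDS request: %zu bytes\n", lds_bytes); // 49152: the main loop wins
}
// --- end example -----------------------------------------------------------------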
template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__device__ static constexpr bool IsValidCompilationParameter()
{
    enum struct Arch : bool
    {
#if defined(__gfx950__)
        is_gfx950_build = true,
#else
        is_gfx950_build = false,
#endif
    };
    // ...
    if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
    // ...
    return ck::tensor_operation::device::IsValidGemmCompilationParameter<
        // ...
        CGlobalMemoryDataOperation>();
}
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
              "Invalid tuning param!");
if(!(karg.M % MPerBlock == 0))
{
    // ...
    std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
              << std::endl;
    // ...
}
if(!(karg.N % NPerBlock == 0))
{
    // ...
    std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
              << std::endl;
    // ...
}
auto K_t = karg.KBatch * KPerBlock;
if(!(karg.K % K_t == 0))
{
    // ...
    std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
              << karg.K << " " << __FILE__ << ":" << __LINE__
              << ", in function: " << __func__ << std::endl;
    // ...
}
auto K_t = karg.KBatch * KReadVec;
// ...
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{
    // ...
    std::cout << "Arg K (" << karg.K
              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
              << __LINE__ << ", in function: " << __func__ << std::endl;
    // ...
}
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{
    // ...
    std::cout << "Arg M (" << karg.M
              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
              << __LINE__ << ", in function: " << __func__ << std::endl;
    // ...
}
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{
    // ...
    std::cout << "Arg N (" << karg.N
              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
              << __LINE__ << ", in function: " << __func__ << std::endl;
    // ...
}
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{
    // ...
    std::cout << "Arg K (" << karg.K
              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
              << __LINE__ << ", in function: " << __func__ << std::endl;
    // ...
}
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
    // ...
    std::cout << "Arg N (" << karg.N
              << ") value is not a multiple of "
                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
              << std::endl;
    // ...
}
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
    // ...
    std::cout << "Arg M (" << karg.M
              << ") value is not a multiple of "
                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
              << std::endl;
    // ...
}
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet " << __FILE__
          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
// ...
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
const index_t num_loop = K / KPerBlock;

return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
// ...
const index_t num_loop = K / KPerBlock;

return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
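// --- standalone example (compile separately; the threshold is an assumption) ----
// Whether the pipeline gets a main loop, and which tail it needs, follows from
// num_loop = K / KPerBlock relative to the pipeline's prefetch depth, which the
// real code queries through BlockwiseGemmPipe. A compact model:
#include <cstdio>
#include <initializer_list>

constexpr bool block_has_hotloop(int num_loop, int prefetch_stages)
{
    // a hot loop exists only when there are more K tiles than the prologue alone
    // can consume
    return num_loop > prefetch_stages;
}

int main()
{
    constexpr int KPerBlock = 64, PrefetchStages = 2; // assumed values
    for(int K : {64, 128, 256, 1024})
    {
        const int num_loop = K / KPerBlock;
        std::printf("K=%4d -> num_loop=%2d, hot loop: %s\n",
                    K,
                    num_loop,
                    block_has_hotloop(num_loop, PrefetchStages) ? "yes" : "no");
    }
}
// --- end example -----------------------------------------------------------------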
template <typename CGridDesc>
__host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
    const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
    // ...
    return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
template <typename AGridDesc_AK0_M_K1,
          typename BGridDesc_BK0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          bool HasMainKBlockLoop,
          // ...
          >
__device__ static void Run(const ADataType* p_a_grid,
                           const BDataType* p_b_grid,
                           CDataType* p_c_grid,
                           // ...
                           const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                           const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                           const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                               c_grid_desc_mblock_mperblock_nblock_nperblock)
{
    const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
    const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
    auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
    // ...
    const auto block_work_idx =
    // ...

    if(!block_2_ctile_map.ValidCTileIndex(
           // ...
           make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                      c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
    // ...

    const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
    const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

    const index_t m_block_data_idx_on_grid =
        __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);

    const index_t n_block_data_idx_on_grid =
        __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
    auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
        // ...
        AElementwiseOperation,
        // ...
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        // ...
        decltype(a_grid_desc_ak0_m_ak1),
        decltype(a_block_desc_ak0_m_ak1),
        ABlockTransferSrcAccessOrder,
        // ...
        ABlockTransferSrcVectorDim,
        // ...
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        // ...
        AThreadTransferSrcResetCoordinateAfterRun,
        // ...
        BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                            // ...
                                            a_block_desc_ak0_m_ak1,
                                            // ...
                                            );

    auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
        // ...
        BElementwiseOperation,
        // ...
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        // ...
        decltype(b_grid_desc_bk0_n_bk1),
        decltype(b_block_desc_bk0_n_bk1),
        BBlockTransferSrcAccessOrder,
        // ...
        BBlockTransferSrcVectorDim,
        // ...
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        // ...
        BThreadTransferSrcResetCoordinateAfterRun,
        // ...
        BlockwiseGemmPipe::GlobalBufferNum>(b_grid_desc_bk0_n_bk1,
                                            // ...
                                            b_block_desc_bk0_n_bk1,
                                            // ...
                                            );
    // ...
        a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

    auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

    auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                     a_block_space_size_aligned * sizeof(ADataType)),
        b_block_desc_bk0_n_bk1.GetElementSpaceSize());
    // ...
    static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
    // ...
    auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();

    const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
        (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
        /* ... */);

    blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                     a_block_desc_ak0_m_ak1,
                                                                     // ...
                                                                     a_block_slice_copy_step,
                                                                     b_grid_desc_bk0_n_bk1,
                                                                     b_block_desc_bk0_n_bk1,
                                                                     // ...
                                                                     b_block_slice_copy_step,
                                                                     // ...
                                                                     num_k_block_main_loop);
    static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                      NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                  /* ... */);

    constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
    constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

    constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
        blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

    constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
        blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

    constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
    constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
    constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
    constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
    constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
    constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
    constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
    constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

    constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
    // ...

    auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<CShuffleDataType*>(p_shared),
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
    // ...
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
    // ...

    const auto c_thread_mtx_on_block =
        blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

    const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
    const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

    const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
    // ...

    const auto m_thread_data_on_block_idx =
        m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
            /* ... */);

    const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
    // ...

    const auto n_thread_data_on_block_idx =
        n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
            /* ... */);

    const auto& vpgr_to_lds_element_op = [&] {
        if constexpr(DoElementwiseBeforeCShuffle)
        // ...
        else
            return pass_through;
    };

    const auto& lds_to_global_element_op = [&] {
        if constexpr(!DoElementwiseBeforeCShuffle)
        // ...
        else
            return pass_through;
    };
    // VGPR -> LDS shuffle copy (per-thread)
    auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
        // ...
        decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
        decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
        // ...
        CElementwiseOperation,
        // ...
        Sequence<CShuffleMXdlPerWavePerShuffle,
                 CShuffleNXdlPerWavePerShuffle,
                 // ...
        true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
              make_multi_index(/* ... */
                               m_thread_data_on_block_idx[I1],
                               n_thread_data_on_block_idx[I1],
                               m_thread_data_on_block_idx[I2],
                               m_thread_data_on_block_idx[I3],
                               m_thread_data_on_block_idx[I4],
                               n_thread_data_on_block_idx[I2]),
              vpgr_to_lds_element_op()};

    // LDS -> global blockwise copy
    auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
        // ...
        CElementwiseOperation,
        // ...
        CGlobalMemoryDataOperation,
        // ...
        Sequence</* ... */
                 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                 /* ... */
                 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        // ...
        decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
        decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
        // ...
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        // ...
        >{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
          // ...
          c_grid_desc_mblock_mperblock_nblock_nperblock,
          // ...
          lds_to_global_element_op()};
    constexpr auto sfc_c_vgpr =
    // ...
        Sequence<CShuffleMXdlPerWavePerShuffle,
                 CShuffleNXdlPerWavePerShuffle,
                 // ...

    constexpr auto sfc_c_global =
    // ...
        Sequence</* ... */
                 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                 /* ... */
                 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

    constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

    static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

    static_for<0, num_access, 1>{}([&](auto access_id) {
        // ...
        c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                      sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                      // ...
                                      c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                      c_shuffle_block_buf);
        // ...
        c_shuffle_block_copy_lds_to_global.Run(
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            c_shuffle_block_buf,
            c_grid_desc_mblock_mperblock_nblock_nperblock,
            /* ... */);

        if constexpr(access_id < num_access - 1)
        {
            constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
            // ...
            c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
        }
    });
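// --- standalone example (compile separately; a sequential model, not GPU code) --
// Each space-filling-curve access above runs the same beat: barrier, shuffle a
// slice of accumulators into LDS, barrier, stream the slice to global memory,
// then advance the destination window. Modeling that ordering on the host
// (block_sync_lds() becomes a comment, since there is no device barrier here):
#include <cstdio>

int main()
{
    const int num_access = 4; // e.g. the number of MxN shuffle slices per tile

    for(int access_id = 0; access_id < num_access; ++access_id)
    {
        // block_sync_lds(): wait until LDS reads of the previous slice finish
        std::printf("access %d: vgpr -> LDS shuffle\n", access_id);
        // block_sync_lds(): wait until the whole slice has landed in LDS
        std::printf("access %d: LDS -> global store\n", access_id);

        if(access_id < num_access - 1)
            std::printf("access %d: move C window forward\n", access_id);
    }
}
// --- end example -----------------------------------------------------------------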
template <bool HasMainKBlockLoop,
          // ...
          >
__device__ static void Run(const ADataType* p_a_grid,
                           const BDataType* p_b_grid,
                           CDataType* p_c_grid,
                           // ...
                           )
{
    // ...
    const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
    // ...

    Run<decltype(a_grid_desc_ak0_m_ak1),
        decltype(b_grid_desc_bk0_n_bk1),
        decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
        // ...
        CGlobalMemoryDataOperation,
        // ...
        >(/* ... */
          a_grid_desc_ak0_m_ak1,
          b_grid_desc_bk0_n_bk1,
          c_grid_desc_mblock_mperblock_nblock_nperblock);
}
template <typename AGridDesc_AK0_M_K1,
          typename BGridDesc_BK0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          bool HasMainKBlockLoop,
          // ...
          >
__device__ static void Run_2Lds(const ADataType* p_a_grid,
                                const BDataType* p_b_grid,
                                CDataType* p_c_grid,
                                // ...
                                const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                                const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                                const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c_grid_desc_mblock_mperblock_nblock_nperblock)
{
    const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
    const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
    auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
    // ...
    const auto block_work_idx =
    // ...

    if(!block_2_ctile_map.ValidCTileIndex(
           // ...
           make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                      c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
    // ...

    const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
    const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

    const index_t m_block_data_idx_on_grid =
        __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);

    const index_t n_block_data_idx_on_grid =
        __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
    auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
        // ...
        AElementwiseOperation,
        // ...
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        // ...
        decltype(a_grid_desc_ak0_m_ak1),
        decltype(a_block_desc_ak0_m_ak1),
        ABlockTransferSrcAccessOrder,
        // ...
        ABlockTransferSrcVectorDim,
        // ...
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        // ...
        AThreadTransferSrcResetCoordinateAfterRun,
        // ...
        BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                            // ...
                                            a_block_desc_ak0_m_ak1,
                                            // ...
                                            );

    auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
        // ...
        BElementwiseOperation,
        // ...
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        // ...
        decltype(b_grid_desc_bk0_n_bk1),
        decltype(b_block_desc_bk0_n_bk1),
        BBlockTransferSrcAccessOrder,
        // ...
        BBlockTransferSrcVectorDim,
        // ...
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        // ...
        BThreadTransferSrcResetCoordinateAfterRun,
        // ...
        BlockwiseGemmPipe::GlobalBufferNum>(b_grid_desc_bk0_n_bk1,
                                            // ...
                                            b_block_desc_bk0_n_bk1,
                                            // ...
                                            );
    // ...
        a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

    auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

    auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                             a_block_space_size_aligned * sizeof(ADataType)),
        b_block_desc_bk0_n_bk1.GetElementSpaceSize());

    auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

    auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                             a_block_space_size_aligned * sizeof(ADataType)),
        b_block_desc_bk0_n_bk1.GetElementSpaceSize());

    auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
    auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
    // ...
    static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
    // ...
    auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();

    const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
        (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
        /* ... */);

    blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                     a_block_desc_ak0_m_ak1,
                                                                     // ...
                                                                     a_block_slice_copy_step,
                                                                     b_grid_desc_bk0_n_bk1,
                                                                     b_block_desc_bk0_n_bk1,
                                                                     // ...
                                                                     b_block_slice_copy_step,
                                                                     // ...
                                                                     num_k_block_main_loop);
    static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                      NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                  /* ... */);

    constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
    constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

    constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
        blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

    constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
        blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

    constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
    constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
    constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
    constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
    constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
    constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
    constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
    constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

    constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
    // ...

    auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<CShuffleDataType*>(p_shared_0),
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
    // ...
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
    // ...

    const auto c_thread_mtx_on_block =
        blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

    const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
    const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

    const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
    // ...

    const auto m_thread_data_on_block_idx =
        m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
            /* ... */);

    const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
    // ...

    const auto n_thread_data_on_block_idx =
        n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
            /* ... */);

    auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
        // ...
        decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
        decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
        // ...
        Sequence<CShuffleMXdlPerWavePerShuffle,
                 CShuffleNXdlPerWavePerShuffle,
                 // ...
        >{
            c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
            make_multi_index(/* ... */
                             m_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I1],
                             m_thread_data_on_block_idx[I2],
                             m_thread_data_on_block_idx[I3],
                             m_thread_data_on_block_idx[I4],
                             n_thread_data_on_block_idx[I2]),
            // ...
    auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
        // ...
        CElementwiseOperation,
        CGlobalMemoryDataOperation,
        // ...
        Sequence</* ... */
                 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                 /* ... */
                 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        // ...
        decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
        decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
        // ...
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        // ...
        >{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
          // ...
          c_grid_desc_mblock_mperblock_nblock_nperblock,
          // ...
          };
    constexpr auto sfc_c_vgpr =
    // ...
        Sequence<CShuffleMXdlPerWavePerShuffle,
                 CShuffleNXdlPerWavePerShuffle,
                 // ...

    constexpr auto sfc_c_global =
    // ...
        Sequence</* ... */
                 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                 /* ... */
                 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

    constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

    static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

    static_for<0, num_access, 1>{}([&](auto access_id) {
        // ...
        c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                      sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                      // ...
                                      c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                      c_shuffle_block_buf);
        // ...
        c_shuffle_block_copy_lds_to_global.Run(
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            c_shuffle_block_buf,
            c_grid_desc_mblock_mperblock_nblock_nperblock,
            /* ... */);

        if constexpr(access_id < num_access - 1)
        {
            constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
            // ...
            c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
        }
    });
template <bool HasMainKBlockLoop,
          // ...
          >
__device__ static void Run_2Lds(const ADataType* p_a_grid,
                                const BDataType* p_b_grid,
                                CDataType* p_c_grid,
                                // ...
                                )
{
    // ...
    const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
    // ...

    Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
             decltype(b_grid_desc_bk0_n_bk1),
             decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
             // ...
             CGlobalMemoryDataOperation,
             // ...
             >(/* ... */
               a_grid_desc_ak0_m_ak1,
               b_grid_desc_bk0_n_bk1,
               c_grid_desc_mblock_mperblock_nblock_nperblock);
}