template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          TailNumber TailNum /* ... */>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK /* ... */)
#endif
        kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
            p_shared,
            karg);
    }
#endif
}
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          TailNumber TailNum /* ... */>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK /* ... */)
#endif
        kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        // two LDS buffers allow ping-pong double buffering of the A/B tiles
        __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

        GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
            p_shared_0,
            p_shared_1,
            karg);
    }
#endif
}
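// ---------------------------------------------------------------------------
// Illustration only (not from this header): both kernels above are launched
// with one workgroup per (M-tile, N-tile, k-batch); the k-batch coordinate
// selects the K-slice that SplitKBatchOffset converts into pointer offsets.
// `ceil_div` and `grid_size_sketch` are hypothetical helpers for this sketch.
inline int ceil_div(int x, int y) { return (x + y - 1) / y; }

inline int grid_size_sketch(int M, int N, int MPerBlock, int NPerBlock, int KBatch)
{
    // MBlock * NBlock output tiles, replicated KBatch times along split-K;
    // each k-batch accumulates its partial C via AtomicAdd or a later reduce.
    return ceil_div(M, MPerBlock) * ceil_div(N, NPerBlock) * KBatch;
}
// ---------------------------------------------------------------------------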
template <typename ALayout,
          // ...
          typename AccDataType,
          typename CShuffleDataType,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          // ...
          typename ComputeTypeA            = CDataType,
          typename ComputeTypeB            = ComputeTypeA,
          bool PermuteA                    = false,
          bool PermuteB                    = false,
          bool DoElementwiseBeforeCShuffle = false>
struct GridwiseGemm_xdl_cshuffle_v3 // "Universal" GEMM kernel with SplitK support
{
    // ... (fragment of the is_single_rate_mfma condition):
    //         KPerBlock < 128 && MPerXdl == 16))
    // ...
    __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    }
    __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    }
    __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    }
    __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
    {
        constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
        auto K_t                = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
    }
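    // Illustration only: the padding arithmetic above with assumed tile sizes
    // (K = 1000, K_Batch = 2, KPerBlock = 64, AK1Value = 8 are example values):
    //   K_t                = 2 * 64 = 128
    //   CalculateAK0Padded = ((1000 + 127) / 128) * (64 / 8) = 8 * 8 = 64
    //   CalculateKPadded   = ((1000 + 127) / 128) * 64       = 512
    // Each of the 2 k-batches therefore iterates over 512 padded K-elements,
    // i.e. K is effectively rounded up from 1000 to 2 * 512 = 1024.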
    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
    {
        // ...
    }
    __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
        index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
        }();

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            const auto a_grid_desc_m_k =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw, /* ... */);
            const auto a_grid_desc_ak0_m_ak1 =
                transform_tensor_descriptor(a_grid_desc_m_k, /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M only
            const auto a_grid_desc_ak0_m_ak1 =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw, /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K only
            const auto a_grid_desc_ak0_m_ak1 =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw, /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // no padding
            const auto a_grid_desc_ak0_m_ak1 =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw, /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
    }
    __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            // ... (naive (N, K) descriptor, strided per BLayout)
        }();

        static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
                        GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            const auto b_grid_desc_n_k =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw, /* ... */);
            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_n_k, /* ... */);
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N only
            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw, /* ... */);
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K only
            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw, /* ... */);
            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            if constexpr(!PermuteB)
            {
                // no padding
                const auto b_grid_desc_bk0_n_bk1 =
                    transform_tensor_descriptor(b_grid_desc_nraw_kraw, /* ... */);
                return b_grid_desc_bk0_n_bk1;
            }
            else
            {
                // pre-shuffled weight: BGlobal[K / KPerBlock, N, KPerBlock / BK1, BK1]
                constexpr index_t BK01 = KPerBlock / BK1Value;
                const index_t BK0_     = StrideB / BK1Value;
                const index_t BK00     = BK0_ / BK01;

                const auto b_grid_desc_bk00_n_bk01_bk1_permute =
                    make_naive_tensor_descriptor_packed(/* ... */);

                const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
                    b_grid_desc_bk00_n_bk01_bk1_permute, /* ... */);

                return b_grid_desc_bk0_n_bk1_permute;
            }
        }
    }
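    // Illustration only: the pre-shuffled-B split above with assumed values
    // (KPerBlock = 64, BK1Value = 8, StrideB = 256):
    //   BK01 = KPerBlock / BK1Value = 8    K0-slots covered by one K-block
    //   BK0_ = StrideB / BK1Value   = 32   total K0-slots in the shuffled layout
    //   BK00 = BK0_ / BK01          = 4    number of K-blocks merged back together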
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }
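    // Illustration only, with assumed tile sizes: MPerBlock = 128,
    // MXdlPerWave = 2, MPerXdl = 32 gives MWaves = 128 / (2 * 32) = 2,
    // i.e. two waves tile M and each wave owns two 32-wide xdlops slices.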
    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            // ... (naive (M, N) descriptor, strided per CLayout)
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw, /* ... */);
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M only
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw, /* ... */);
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N only
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw, /* ... */);
        }
        else
        {
            // no padding
            return c_grid_desc_mraw_nraw;
        }
    }
    struct Problem
    {
        __host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_,
                         index_t StrideB_, index_t StrideC_, index_t KBatch_,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
                         CElementwiseOperation c_element_op)
        // ... (members initialized from the arguments plus the derived
        //      MPadded/NPadded/KPadded/KRead/AK0/BK0/MBlock/NBlock)
        {
        }
        __host__ void Print() const
        {
            std::cout << "problem {"
                      // ... (M/N/K, strides, and padded sizes)
                      << "KRead:" << KRead << ", "
                      << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << ", "
                      << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
        }

        // ... (index_t members M, N, K, StrideA, StrideB, StrideC, KBatch,
        //      MPadded, NPadded, KPadded, KRead, AK0, BK0, MBlock, NBlock,
        //      and the three elementwise operations)
    };
    struct Argument /* ... */
    {
        __host__ Argument(const ADataType* p_a_grid_,
                          const BDataType* p_b_grid_,
                          CDataType* p_c_grid_,
                          index_t M_, index_t N_, index_t K_, index_t StrideA_,
                          index_t StrideB_, index_t StrideC_, index_t k_batch_,
                          bool is_reduce_                    = false,
                          AElementwiseOperation a_element_op = AElementwiseOperation{},
                          BElementwiseOperation b_element_op = BElementwiseOperation{},
                          CElementwiseOperation c_element_op = CElementwiseOperation{})
        // ...
        {
        }

        // ... (IsReduceAdd()/IsAtomicAdd() helpers; p_a_grid, p_b_grid,
        //      p_c_grid, is_reduce members)
    };
    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(Argument& karg)
        {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
                a_k_split_offset = /* ... */;
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
                a_k_split_offset = /* ... */;

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
                b_k_split_offset = /* ... */;
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                if constexpr(!PermuteB)
                    b_k_split_offset = /* ... */;
                else
                {
                    const int k0_offset = karg.KRead * karg.N;
                    b_k_split_offset    = /* ... */;
                }
            }
            // ... (c_reduce_offset, and trimming KRead for the last k-batch)
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t c_reduce_offset;
    };
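    // Illustration only (assumed layouts): for row-major A with StrideA = K,
    // k-batch `b` starts at column b * KRead, so a_k_split_offset = b * KRead;
    // for column-major A the same slice starts b * KRead * StrideA elements in.
    // The permuted-B branch instead jumps whole K0-blocks, which is why
    // k0_offset above counts KRead rows of N elements each.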
    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave    = NPerBlock / (NXdlPerWave * NPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);

        if constexpr(/* ... no xor-permutation needed ... */)
        {
            // ... (simple padded LDS layout)
        }
        else if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        {
            // xor-permuted layout, row-major A
            constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
            // ... (build a_lds_block_desc_permuted, then fold it back)
            constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_permuted, /* ... */);
            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_ak0_mldslayer_m_ak1, /* ... */);
            return a_lds_block_desc_ak0_m_ak1;
        }
        else
        {
            // xor-permuted layout, column-major A
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(ADataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1 <= mpair <= M0
            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));

            // ... (packed descriptor over lengths including
            //      Number<kfold * M0 / mpair>{}; then the xor-permuted
            //      a_lds_block_desc_permuted, the unmerged
            //      a_lds_block_desc_unmerged, and a final merge back to
            //      (AK0, M, AK1) using Number<KThreadWrite / kfold / KThreadReadPerm>{})
            return a_lds_block_desc_ak0_m_ak1;
        }
    }
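    // Illustration only: the 128-byte folding above with assumed parameters.
    // For ADataType = half_t (2 bytes), AK1Number = 8, M0 = 4:
    //   AK1Number * M0 * sizeof(ADataType) = 8 * 4 * 2 = 64 bytes <= 128
    //   kfold = 128 / 64 = 2
    // so two K-stripes share one 128-byte LDS row, keeping reads wide and
    // bank-conflict free; once a stripe alone exceeds 128 bytes, kfold is 1.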
    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave    = NPerBlock / (NXdlPerWave * NPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);

        if constexpr(/* ... no xor-permutation needed ... */)
        {
            // ... (simple padded LDS layout)
        }
        else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
        {
            // xor-permuted layout
            constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
            // ... (build b_lds_block_desc_permuted, then fold it back)
            constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_permuted, /* ... */);
            constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_bk0_nldslayer_n_bk1, /* ... */);
            return b_lds_block_desc_bk0_n_bk1;
        }
        else
        {
            constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1 = NPerBlock / N0;

            constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
            constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / NPerXdl;
            constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

            constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
                                       ? 1
                                       : 128 / (BK1Number * N0 * sizeof(BDataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1 <= npair <= N0
            constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
                                       ? 1
                                       : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
                                              ? N0
                                              : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));

            // ... (packed descriptor over lengths including
            //      Number<kfold * N0 / npair>{}; then the xor-permuted
            //      b_lds_block_desc_permuted, the unmerged
            //      b_lds_block_desc_unmerged, and a final merge back to
            //      (BK0, N, BK1) using Number<KThreadWrite / kfold / KThreadReadPerm>{})
            return b_lds_block_desc_bk0_n_bk1;
        }
    }
    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                /* ... (1, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                        1, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) */);

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }
    using BlockwiseGemmPipe = remove_cvref_t<decltype(BlockGemmPipeline_Selector<
        /* ... */
        ABlockTransferSrcScalarPerVector,
        BBlockTransferSrcScalarPerVector,
        /* ... */>())>;

    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // LDS allocation for A and B, aligned up to max_lds_align
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for the C-shuffle stage
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        // the A/B tiles and the C-shuffle buffer reuse the same LDS allocation
        return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
                          b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
                         c_block_size * sizeof(CShuffleDataType));
    }
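    // Illustration only, assumed sizes: 128x64 A and 128x64 B tiles of half_t
    // are 128 * 64 * 2 = 16 KiB each, so the GEMM stage needs 32 KiB; if a
    // float C-shuffle staging buffer needs 128 * 32 * 4 = 16 KiB, max() still
    // returns 32 KiB because the epilogue reuses the same allocation.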
    template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
    __device__ static constexpr bool IsValidCompilationParameter()
    {
        enum struct Arch : bool
        {
#if defined(__gfx950__)
            is_gfx950_build = true,
#else
            is_gfx950_build = false,
#endif
        };

        if constexpr(static_cast<bool>(Arch::is_gfx950_build) || /* ... */)
        {
            // ...
        }

#if defined(__gfx11__) || defined(__gfx12__)
        // these targets only support the 16x16 tile here
        if constexpr(MPerXdl != 16 || NPerXdl != 16)
        {
            return false;
        }
#if defined(__gfx11__)
        constexpr bool SupportMemOp =
            sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation == /* ... */);
        if constexpr(SupportMemOp == false)
        {
            return false;
        }
#endif
#endif
        if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
        {
            constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
            if constexpr(MWaves > 0 && NWaves > 0)
            {
                constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
                // ... (WaveSize must match the wavefront width of the target)
            }
        }
        // ...
        return true;
    }
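    // Illustration only: with BlockSize = 256 and MWaves = NWaves = 2
    // (example values), WaveSize = 256 / (2 * 2) = 64, the wavefront width
    // expected on gfx9 targets.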
    static constexpr __host__ bool CheckValidity(const Argument& karg)
    {
        if constexpr((MPerXdl * MXdlPerWave) == 0 || (NXdlPerWave * NPerXdl) == 0)
        {
            return false;
        }
        if constexpr((MPerBlock % (MPerXdl * MXdlPerWave) != 0) ||
                     (NPerBlock % (NXdlPerWave * NPerXdl) != 0))
        {
            return false;
        }

        // ... (the M/N multiples below are only checked when GemmSpec does
        //      not pad the corresponding dimension)
        if(!(karg.M % MPerBlock == 0))
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          << std::endl;
            }
            return false;
        }

        if(!(karg.N % NPerBlock == 0))
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          << std::endl;
            }
            return false;
        }

        {
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                              << karg.K << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        // make sure the last k-batch still has work
        {
            constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
            auto K_t                = karg.KBatch * KReadVec;
            auto KReadPadSplited    = (karg.K + K_t - 1) / K_t * KReadVec;
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
            {
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        {
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
        {
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
        {
            if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of "
                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of "
                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                return false;
            }
        }

        // ... (KBatch > 1 additionally requires a C data type that supports
        //      AtomicAdd or the reduce path)
        if(karg.KBatch > 1 /* ... */)
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet " << __FILE__
                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
            }
            return false;
        }

        // the number of k-loops must exceed the pipeline's prefetch depth
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
        // ...
        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
        {
            return false;
        }

        return true;
    }
    static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }

    static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }
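    // Illustration only: with K = 4096 and KPerBlock = 64 (example values),
    // num_loop = 64; whether that counts as a hot loop, and which tail
    // variant to run, is delegated to the selected BlockwiseGemmPipe so the
    // decision tracks that pipeline's prefetch depth.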
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            /* ... unmerge (M, N) into (MBlock, MPerBlock, NBlock, NPerBlock) */);

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum /* ... */>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               const Problem& problem,
                               const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // divide block work by [M, N]; block_2_ctile_map construction elided
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // readfirstlane forces the tile origins into SGPRs
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        // A/B block descriptors in LDS
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                // ...
                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                // ...
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                // ...
                                                ABlockTransferSrcVectorDim,
                                                // ...
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_AK1,
                                                // ...
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                // ...
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                /* ... */,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                /* ... */);

        // B matrix blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                // ...
                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                // ...
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                // ...
                                                BBlockTransferSrcVectorDim,
                                                // ...
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_BK1,
                                                // ...
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                // ...
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                /* ... */,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                /* ... */);

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                         a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);

        // Blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_buf,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_buf,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);

        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // only used to query the lengths below
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                /* ... split (MPerBlock, NPerBlock) into the xdlops output dims */);

            // calculate origin of thread output tensor on global memory
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(/* ... */);

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(/* ... */);

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // apply the C elementwise op either before or after the shuffle
            const auto& vpgr_to_lds_element_op = [&] {
                if constexpr(DoElementwiseBeforeCShuffle)
                {
                    return c_element_op;
                }
                else
                {
                    return pass_through;
                }
            };
            const auto& lds_to_global_element_op = [&] {
                if constexpr(!DoElementwiseBeforeCShuffle)
                {
                    return c_element_op;
                }
                else
                {
                    return pass_through;
                }
            };

            // VGPR-to-LDS copy of each shuffle slice
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                AccDataType,
                CShuffleDataType,
                decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                /* ... */
                CElementwiseOperation,
                /* ... */
                Sequence<CShuffleMXdlPerWavePerShuffle,
                         CShuffleNXdlPerWavePerShuffle,
                         /* ... */>,
                /* ... */
                true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                      make_multi_index(0,
                                       0,
                                       m_thread_data_on_block_idx[I1],
                                       n_thread_data_on_block_idx[I1],
                                       m_thread_data_on_block_idx[I2],
                                       m_thread_data_on_block_idx[I3],
                                       m_thread_data_on_block_idx[I4],
                                       n_thread_data_on_block_idx[I2]),
                      vpgr_to_lds_element_op()};

            // blockwise LDS-to-global copy of each shuffle slice
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,
                CElementwiseOperation,
                CGlobalMemoryDataOperation,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                /* ... */
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                /* ... */
                CShuffleBlockTransferScalarPerVector_NPerBlock,
                /* ... */>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 /* ... */
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 /* ... */
                 lds_to_global_element_op()};

            // space-filling curve over the per-thread C tile in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve</* ... */
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           /* ... */>>{};

            // space-filling curve over the shuffled C tile in global memory
            constexpr auto sfc_c_global =
                SpaceFillingCurve</* ... */
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread writes its slice from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // the block copies the shuffled slice from LDS to global
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move the destination window on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
    }
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum /* ... */>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               const Problem& problem)
    {
        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(/* ... */);

        Run<decltype(a_grid_desc_ak0_m_ak1),
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
            HasMainKBlockLoop,
            CGlobalMemoryDataOperation,
            TailNum>(p_a_grid,
                     p_b_grid,
                     p_c_grid,
                     p_shared,
                     problem,
                     a_grid_desc_ak0_m_ak1,
                     b_grid_desc_bk0_n_bk1,
                     c_grid_desc_mblock_mperblock_nblock_nperblock);
    }
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum /* ... */>
    __device__ static void Run_2Lds(const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    CDataType* p_c_grid,
                                    void* p_shared_0,
                                    void* p_shared_1,
                                    const Problem& problem,
                                    const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                                    const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                                    const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                        c_grid_desc_mblock_mperblock_nblock_nperblock)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // divide block work by [M, N]; block_2_ctile_map construction elided
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // readfirstlane forces the tile origins into SGPRs
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        // A/B block descriptors in LDS
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                // ...
                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                // ...
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                // ...
                                                ABlockTransferSrcVectorDim,
                                                // ...
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_AK1,
                                                // ...
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                // ...
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                /* ... */,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                /* ... */);

        // B matrix blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                // ...
                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                // ...
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                // ...
                                                BBlockTransferSrcVectorDim,
                                                // ...
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_BK1,
                                                // ...
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                // ...
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                /* ... */,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                /* ... */);

        // ping-pong LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);

        // Blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_bufs,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_bufs,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);

        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // only used to query the lengths below
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared_0),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                /* ... split (MPerBlock, NPerBlock) into the xdlops output dims */);

            // calculate origin of thread output tensor on global memory
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(/* ... */);

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(/* ... */);

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // VGPR-to-LDS copy of each shuffle slice
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                AccDataType,
                CShuffleDataType,
                decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                /* ... */
                Sequence<CShuffleMXdlPerWavePerShuffle,
                         CShuffleNXdlPerWavePerShuffle,
                         /* ... */>,
                /* ... */>{
                c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                make_multi_index(0,
                                 0,
                                 m_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 m_thread_data_on_block_idx[I3],
                                 m_thread_data_on_block_idx[I4],
                                 n_thread_data_on_block_idx[I2]),
                /* ... */};

            // blockwise LDS-to-global copy of each shuffle slice
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,
                CElementwiseOperation,
                CGlobalMemoryDataOperation,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                /* ... */
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                /* ... */
                CShuffleBlockTransferScalarPerVector_NPerBlock,
                /* ... */>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 /* ... */
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 /* ... */};

            // space-filling curve over the per-thread C tile in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve</* ... */
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           /* ... */>>{};

            // space-filling curve over the shuffled C tile in global memory
            constexpr auto sfc_c_global =
                SpaceFillingCurve</* ... */
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread writes its slice from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // the block copies the shuffled slice from LDS to global
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move the destination window on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
    }
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum /* ... */>
    __device__ static void Run_2Lds(const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    CDataType* p_c_grid,
                                    void* p_shared_0,
                                    void* p_shared_1,
                                    const Problem& problem)
    {
        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(/* ... */);

        Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
                 decltype(b_grid_desc_bk0_n_bk1),
                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                 HasMainKBlockLoop,
                 CGlobalMemoryDataOperation,
                 TailNum>(p_a_grid,
                          p_b_grid,
                          p_c_grid,
                          p_shared_0,
                          p_shared_1,
                          problem,
                          a_grid_desc_ak0_m_ak1,
                          b_grid_desc_bk0_n_bk1,
                          c_grid_desc_mblock_mperblock_nblock_nperblock);
    }
};
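// ---------------------------------------------------------------------------
// Illustration only (not part of the header): a typical host-side flow for a
// fully instantiated GridwiseGemm type. The `Gemm` alias and the launch-config
// derivation are assumptions of this sketch; in practice the device-op
// wrappers around this header take care of both.
//
//   using Gemm = GridwiseGemm_xdl_cshuffle_v3</* full parameter list */>;
//
//   Gemm::Argument karg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, k_batch};
//   if(!Gemm::CheckValidity(karg))
//       return; // unsupported shape / vector-width combination
//
//   // choose kernel_gemm_xdl_cshuffle_v3 or kernel_gemm_xdl_cshuffle_v3_2lds
//   // depending on the selected pipeline, query CalculateHasMainKBlockLoop /
//   // CalculateKBlockLoopTailNum for the template arguments, and launch one
//   // workgroup per (M-tile, N-tile, k-batch).
// ---------------------------------------------------------------------------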