26 template <
typename GridwiseGemm,
 
   27           bool HasMainKBlockLoop,
 
   32 #if CK_USE_LAUNCH_BOUNDS 
   38 #if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__) 
   39     if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
 
   41         __shared__ 
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   43         auto splitk_batch_offset = 
typename GridwiseGemm::SplitKBatchOffset(karg);
 
   45         GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
 
   46             karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 
   47             karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 
   48             karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
 
   57 template <
typename GridwiseGemm,
 
   58           bool HasMainKBlockLoop,
 
   63 #if CK_USE_LAUNCH_BOUNDS 
   69 #if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__) 
   72     if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
 
   74         __shared__ 
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   75         __shared__ 
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   77         auto splitk_batch_offset = 
typename GridwiseGemm::SplitKBatchOffset(karg);
 
   79         GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
 
   80             karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 
   81             karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 
   82             karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
 
  197 template <
typename ALayout,
 
  202           typename AccDataType,
 
  203           typename CShuffleDataType,
 
  205           typename AElementwiseOperation,
 
  206           typename BElementwiseOperation,
 
  207           typename CElementwiseOperation,
 
  219           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  220           typename ABlockTransferThreadClusterArrangeOrder,
 
  221           typename ABlockTransferSrcAccessOrder,
 
  222           index_t ABlockTransferSrcVectorDim,
 
  223           index_t ABlockTransferSrcScalarPerVector,
 
  224           index_t ABlockTransferDstScalarPerVector_AK1,
 
  225           bool AThreadTransferSrcResetCoordinateAfterRun,
 
  227           typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  228           typename BBlockTransferThreadClusterArrangeOrder,
 
  229           typename BBlockTransferSrcAccessOrder,
 
  230           index_t BBlockTransferSrcVectorDim,
 
  231           index_t BBlockTransferSrcScalarPerVector,
 
  232           index_t BBlockTransferDstScalarPerVector_BK1,
 
  233           bool BThreadTransferSrcResetCoordinateAfterRun,
 
  235           index_t CShuffleMXdlPerWavePerShuffle,
 
  236           index_t CShuffleNXdlPerWavePerShuffle,
 
  237           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  238           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
 
  241           typename ComputeTypeA                       = CDataType,
 
  242           typename ComputeTypeB                       = ComputeTypeA,
 
  243           bool PermuteA                               = 
false,
 
  244           bool PermuteB                               = 
false,
 
  245           bool DoElementwiseBeforeCShuffle            = 
false>
 
  270           KPerBlock < 128 && MPerXdl == 16))
 
  321         auto K_t = K_Batch * KPerBlock;
 
  322         return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
 
  327         auto K_t = K_Batch * KPerBlock;
 
  328         return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
 
  333         auto K_t = K_Batch * KPerBlock;
 
  334         return (K + K_t - 1) / K_t * KPerBlock;
 
  340         auto K_t                = K_Batch * KReadVec;
 
  341         return (K + K_t - 1) / K_t * KReadVec;
 
  354     template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, 
typename TileDesc_K0_MN_K1>
 
  372         const auto a_grid_desc_mraw_kraw = [&]() {
 
  373             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
 
  377             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
 
  385         if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
 
  386                      GemmSpec == GemmSpecialization::MNKPadding)
 
  389             const auto a_grid_desc_m_k =
 
  403             return a_grid_desc_ak0_m_ak1;
 
  405         else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
 
  406                           GemmSpec == GemmSpecialization::MNPadding)
 
  410                 a_grid_desc_mraw_kraw,
 
  416             return a_grid_desc_ak0_m_ak1;
 
  418         else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
 
  419                           GemmSpec == GemmSpecialization::NKPadding)
 
  423                 a_grid_desc_mraw_kraw,
 
  435             return a_grid_desc_ak0_m_ak1;
 
  441                 a_grid_desc_mraw_kraw,
 
  447             return a_grid_desc_ak0_m_ak1;
 
  454         const auto b_grid_desc_nraw_kraw = [&]() {
 
  468                         GemmSpec != GemmSpecialization::Default),
 
  469                       "pk_i4_t does not support padding");
 
  471         if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
 
  472                      GemmSpec == GemmSpecialization::MNKPadding)
 
  475             const auto b_grid_desc_n_k =
 
  489             return b_grid_desc_bk0_n_bk1;
 
  491         else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
 
  492                           GemmSpec == GemmSpecialization::MNPadding)
 
  496                 b_grid_desc_nraw_kraw,
 
  502             return b_grid_desc_bk0_n_bk1;
 
  504         else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
 
  505                           GemmSpec == GemmSpecialization::MKPadding)
 
  509                 b_grid_desc_nraw_kraw,
 
  521             return b_grid_desc_bk0_n_bk1;
 
  525             if constexpr(!PermuteB)
 
  529                     b_grid_desc_nraw_kraw,
 
  535                 return b_grid_desc_bk0_n_bk1;
 
  541                 constexpr 
index_t BK01 = KPerBlock / BK1Value;
 
  542                 const index_t BK0_     = StrideB / BK1Value;
 
  543                 const index_t BK00     = BK0_ / BK01;
 
  545                 const auto b_grid_desc_bk00_n_bk01_bk1_permute =
 
  549                     b_grid_desc_bk00_n_bk01_bk1_permute,
 
  556                 return b_grid_desc_bk0_n_bk1_permute;
 
  561     template <
typename ABlockDesc_AK0_M_AK1>
 
  562     __host__ __device__ 
static constexpr 
auto 
  565         constexpr 
index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
 
  567         return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
 
  570     template <
typename BBlockDesc_BK0_N_BK1>
 
  571     __host__ __device__ 
static constexpr 
auto 
  574         constexpr 
index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
 
  576         return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
 
  579     __host__ __device__ 
static auto 
  582         const auto c_grid_desc_mraw_nraw = [&]() {
 
  602         if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
 
  603                      GemmSpec == GemmSpecialization::MNKPadding)
 
  612         else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
 
  613                           GemmSpec == GemmSpecialization::MKPadding)
 
  617                 c_grid_desc_mraw_nraw,
 
  622         else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
 
  623                           GemmSpec == GemmSpecialization::NKPadding)
 
  627                 c_grid_desc_mraw_nraw,
 
  635             return c_grid_desc_mraw_nraw;
 
  649                          AElementwiseOperation a_element_op,
 
  650                          BElementwiseOperation b_element_op,
 
  651                          CElementwiseOperation c_element_op)
 
  676             std::cout << 
"problem {"  
  685                       << 
"KRead:" << 
KRead << 
", "  
  687                       << 
"AK0:" << 
AK0 << 
", "  
  688                       << 
"BK0:" << 
BK0 << 
", "  
  689                       << 
"MBlock: " << 
MBlock << 
", " 
  690                       << 
"NBlock: " << 
NBlock << 
"}" << std::endl;
 
  718                           const BDataType* p_b_grid_,
 
  719                           CDataType* p_c_grid_,
 
  727                           bool is_reduce_                    = 
false,
 
  728                           AElementwiseOperation 
a_element_op = AElementwiseOperation{},
 
  729                           BElementwiseOperation 
b_element_op = BElementwiseOperation{},
 
  730                           CElementwiseOperation 
c_element_op = CElementwiseOperation{})
 
  769             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
 
  773             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
 
  778             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
 
  782             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
 
  784                 if constexpr(!PermuteB)
 
  790                     const int k0_offset = karg.
KRead * karg.
N;
 
  821         constexpr 
index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
 
  822         constexpr 
index_t NWave    = NPerBlock / (NXdlPerWave * NPerXdl);
 
  823         constexpr 
index_t WaveSize = BlockSize / (MWave * NWave);   
 
  838             constexpr 
auto MLdsLayer        = LdsSize < 1 ? 1 : LdsSize;
 
  853                 a_lds_block_desc_permuted,
 
  861                 a_lds_block_desc_ak0_mldslayer_m_ak1,
 
  869             return a_lds_block_desc_ak0_m_ak1;
 
  876             constexpr 
auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
 
  877             constexpr 
auto M1 = MPerBlock / M0;
 
  879             constexpr 
auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
 
  880             constexpr 
auto K0PerThreadWrite = 
AK0Number / KThreadWrite;
 
  881             constexpr 
auto KThreadRead      = WaveSize / MPerXdl;
 
  882             constexpr 
auto K0PerThreadRead  = 
AK0Number / KThreadRead;
 
  884             constexpr 
auto kfold = (
AK1Number * M0 * 
sizeof(ADataType) > 128)
 
  886                                        : 128 / (
AK1Number * M0 * 
sizeof(ADataType));
 
  887             constexpr 
auto KThreadReadPerm =
 
  888                 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
 
  889                     ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
 
  893             constexpr 
auto mpair = (
AK1Number * MPerXdl * 
sizeof(ADataType) > 128)
 
  895                                        : ((128 / (
AK1Number * MPerXdl * 
sizeof(ADataType))) > M0
 
  897                                               : 128 / (
AK1Number * MPerXdl * 
sizeof(ADataType)));
 
  903                            Number<kfold * M0 / mpair>{},
 
  922                 a_lds_block_desc_permuted,
 
  944                 a_lds_block_desc_unmerged,
 
  947                                           Number<KThreadWrite / kfold / KThreadReadPerm>{},
 
  956             return a_lds_block_desc_ak0_m_ak1;
 
  962         constexpr 
index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
 
  963         constexpr 
index_t NWave    = NPerBlock / (NXdlPerWave * NPerXdl);
 
  964         constexpr 
index_t WaveSize = BlockSize / (MWave * NWave);
 
  978             constexpr 
index_t NLdsLayer     = LdsSize < 1 ? 1 : LdsSize;
 
  993                 b_lds_block_desc_permuted,
 
 1001                 b_lds_block_desc_bk0_nldslayer_n_bk1,
 
 1009             return b_lds_block_desc_bk0_n_bk1;
 
 1013             constexpr 
auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I1);
 
 1014             constexpr 
auto N1 = NPerBlock / N0;
 
 1016             constexpr 
auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I0);
 
 1017             constexpr 
auto K0PerThreadWrite = 
BK0Number / KThreadWrite;
 
 1018             constexpr 
auto KThreadRead      = WaveSize / NPerXdl;
 
 1019             constexpr 
auto K0PerThreadRead  = 
BK0Number / KThreadRead;
 
 1021             constexpr 
auto kfold = (
BK1Number * N0 * 
sizeof(BDataType) > 128)
 
 1023                                        : 128 / (
BK1Number * N0 * 
sizeof(BDataType));
 
 1024             constexpr 
auto KThreadReadPerm =
 
 1025                 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
 
 1026                     ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
 
 1030             constexpr 
auto npair = (
BK1Number * NPerXdl * 
sizeof(BDataType) > 128)
 
 1032                                        : ((128 / (
BK1Number * NPerXdl * 
sizeof(BDataType))) > N0
 
 1034                                               : 128 / (
BK1Number * NPerXdl * 
sizeof(BDataType)));
 
 1040                            Number<kfold * N0 / npair>{},
 
 1059                 b_lds_block_desc_permuted,
 
 1081                 b_lds_block_desc_unmerged,
 
 1084                                           Number<KThreadWrite / kfold / KThreadReadPerm>{},
 
 1093             return b_lds_block_desc_bk0_n_bk1;
 
 1099         constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
 1100         constexpr 
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
 
 1102         constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 1109         return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
 
 1127                                 ABlockTransferSrcScalarPerVector,
 
 1128                                 BBlockTransferSrcScalarPerVector,
 
 1148             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
 1151             b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
 
 1154         constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 1157         constexpr 
auto c_block_size =
 
 1158             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
 
 1161                           b_block_space_size_aligned * 
sizeof(BDataType) / 
BPackedSize),
 
 1162                          c_block_size * 
sizeof(CShuffleDataType));
 
 1165     template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
 
 1168         enum struct Arch : bool
 
 1170 #if defined(__gfx950__) 
 1171             is_gfx950_build = 
true,
 
 1173             is_gfx950_build = 
false,
 
 1178         if constexpr(
static_cast<bool>(Arch::is_gfx950_build) ||
 
 1190         return ck::tensor_operation::device::IsValidGemmCompilationParameter<
 
 1199                    CGlobalMemoryDataOperation>();
 
 1204         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
 
 1205                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
 
 1206                       "Invalid tuning param!");
 
 1214             if(!(karg.
M % MPerBlock == 0))
 
 1218                     std::cout << 
"Arg M value is not a multiple of MPerBlock! M: " << karg.
M << 
" " 
 1219                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1232             if(!(karg.
N % NPerBlock == 0))
 
 1236                     std::cout << 
"Arg N value is not a multiple of NPerBlock! N: " << karg.
N << 
" " 
 1237                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1250             auto K_t = karg.
KBatch * KPerBlock;
 
 1251             if(!(karg.
K % K_t == 0))
 
 1255                     std::cout << 
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " 
 1256                               << karg.
K << 
" " << __FILE__ << 
":" << __LINE__
 
 1257                               << 
", in function: " << __func__ << std::endl;
 
 1265             auto K_t                = karg.
KBatch * KReadVec;
 
 1267             if((KReadPadSplited * (karg.
KBatch - 1)) >= karg.
K)
 
 1275             if(karg.
K % ABlockTransferSrcScalarPerVector != 0)
 
 1279                     std::cout << 
"Arg K (" << karg.
K 
 1280                               << 
") value is not a multiple of ABlockTransferSrcScalarPerVector (" 
 1281                               << ABlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1282                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1289             if(karg.
M % ABlockTransferSrcScalarPerVector != 0)
 
 1293                     std::cout << 
"Arg M (" << karg.
M 
 1294                               << 
") value is not a multiple of ABlockTransferSrcScalarPerVector (" 
 1295                               << ABlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1296                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1304             if(karg.
N % BBlockTransferSrcScalarPerVector != 0)
 
 1308                     std::cout << 
"Arg N (" << karg.
N 
 1309                               << 
") value is not a multiple of BBlockTransferSrcScalarPerVector (" 
 1310                               << BBlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1311                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1318             if(karg.
K % BBlockTransferSrcScalarPerVector != 0)
 
 1322                     std::cout << 
"Arg K (" << karg.
K 
 1323                               << 
") value is not a multiple of BBlockTransferSrcScalarPerVector (" 
 1324                               << BBlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1325                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1333             if(karg.
N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
 
 1337                     std::cout << 
"Arg N (" << karg.
N 
 1338                               << 
") value is not a multiple of " 
 1339                                  "CShuffleBlockTransferScalarPerVector_NPerBlock (" 
 1340                               << CShuffleBlockTransferScalarPerVector_NPerBlock << 
" )! " 
 1341                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1349             if(karg.
M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
 
 1353                     std::cout << 
"Arg M (" << karg.
M 
 1354                               << 
") value is not a multiple of " 
 1355                                  "CShuffleBlockTransferScalarPerVector_NPerBlock (" 
 1356                               << CShuffleBlockTransferScalarPerVector_NPerBlock << 
" )! " 
 1357                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1373                     std::cout << 
" KBatch: " << karg.
KBatch << 
" > 1 is not support yet" << __FILE__
 
 1374                               << 
":" << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1384         const auto num_k_loop = karg.
AK0 / (KPerBlock / AK1Value);
 
 1388             if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
 
 1400         const index_t num_loop = K / KPerBlock;
 
 1402         return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
 
 1407         const index_t num_loop = K / KPerBlock;
 
 1409         return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
 
 1412     template <
typename CGr
idDesc>
 
 1414         const CGridDesc& c_grid_desc_m_n, 
index_t MBlock, 
index_t NBlock)
 
 1423         return c_grid_desc_mblock_mperblock_nblock_nperblock;
 
 1431     template <
typename AGridDesc_AK0_M_K1,
 
 1432               typename BGridDesc_BK0_N_K1,
 
 1433               typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
 
 1434               bool HasMainKBlockLoop,
 
 1437     __device__ 
static void Run(
const ADataType* p_a_grid,
 
 1438                                const BDataType* p_b_grid,
 
 1439                                CDataType* p_c_grid,
 
 1442                                const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
 
 1443                                const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
 
 1444                                const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
 
 1445                                    c_grid_desc_mblock_mperblock_nblock_nperblock)
 
 1447         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1448             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1449         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1450             p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1451         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1452             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 1457         const auto block_work_idx =
 
 1460         if(!block_2_ctile_map.ValidCTileIndex(
 
 1462                make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
 
 1463                           c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
 
 1468         const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
 
 1469         const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
 
 1472         const index_t m_block_data_idx_on_grid =
 
 1473             __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
 
 1475         const index_t n_block_data_idx_on_grid =
 
 1476             __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
 
 1488         auto a_blockwise_copy =
 
 1490                                                 AElementwiseOperation,
 
 1494                                                 ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
 1495                                                 ABlockTransferThreadClusterArrangeOrder,
 
 1498                                                 decltype(a_grid_desc_ak0_m_ak1),
 
 1499                                                 decltype(a_block_desc_ak0_m_ak1),
 
 1500                                                 ABlockTransferSrcAccessOrder,
 
 1502                                                 ABlockTransferSrcVectorDim,
 
 1504                                                 ABlockTransferSrcScalarPerVector,
 
 1505                                                 ABlockTransferDstScalarPerVector_AK1,
 
 1508                                                 AThreadTransferSrcResetCoordinateAfterRun,
 
 1510                                                 BlockwiseGemmPipe::GlobalBufferNum>(
 
 1511                 a_grid_desc_ak0_m_ak1,
 
 1514                 a_block_desc_ak0_m_ak1,
 
 1519         auto b_blockwise_copy =
 
 1521                                                 BElementwiseOperation,
 
 1525                                                 BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
 1526                                                 BBlockTransferThreadClusterArrangeOrder,
 
 1529                                                 decltype(b_grid_desc_bk0_n_bk1),
 
 1530                                                 decltype(b_block_desc_bk0_n_bk1),
 
 1531                                                 BBlockTransferSrcAccessOrder,
 
 1533                                                 BBlockTransferSrcVectorDim,
 
 1535                                                 BBlockTransferSrcScalarPerVector,
 
 1536                                                 BBlockTransferDstScalarPerVector_BK1,
 
 1539                                                 BThreadTransferSrcResetCoordinateAfterRun,
 
 1541                                                 BlockwiseGemmPipe::GlobalBufferNum>(
 
 1542                 b_grid_desc_bk0_n_bk1,
 
 1545                 b_block_desc_bk0_n_bk1,
 
 1551             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
 1554         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1555             static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1557         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1558             reinterpret_cast<BDataType*
>(
static_cast<char*
>(p_shared) + a_block_space_size_aligned *
 
 1561             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1567         static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
 
 1569         auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
 
 1571         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
 
 1572             (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
 
 1575         blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
 
 1576                                                                          a_block_desc_ak0_m_ak1,
 
 1580                                                                          a_block_slice_copy_step,
 
 1581                                                                          b_grid_desc_bk0_n_bk1,
 
 1582                                                                          b_block_desc_bk0_n_bk1,
 
 1586                                                                          b_block_slice_copy_step,
 
 1588                                                                          num_k_block_main_loop);
 
 1592             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
 
 1593                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
 
 1596             constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
 1597             constexpr 
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
 
 1600             constexpr 
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
 1601                 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 1605             constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
 
 1606                 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 1608             constexpr 
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
 
 1609             constexpr 
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
 
 1610             constexpr 
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
 
 1611             constexpr 
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
 
 1612             constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
 
 1613             constexpr 
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
 
 1614             constexpr 
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
 
 1615             constexpr 
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
 
 1617             constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 1620             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1621                 static_cast<CShuffleDataType*
>(p_shared),
 
 1622                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 1625                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 1645             const auto c_thread_mtx_on_block =
 
 1646                 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
 1648             const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
 1649             const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
 1651             const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
 1657             const auto m_thread_data_on_block_idx =
 
 1658                 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
 1661             const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
 1667             const auto n_thread_data_on_block_idx =
 
 1668                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
 1672             const auto& vpgr_to_lds_element_op = [&] {
 
 1673                 if constexpr(DoElementwiseBeforeCShuffle)
 
 1679                     return pass_through;
 
 1682             const auto& lds_to_global_element_op = [&] {
 
 1683                 if constexpr(!DoElementwiseBeforeCShuffle)
 
 1689                     return pass_through;
 
 1697                 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 1698                 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 1700                               CElementwiseOperation,
 
 1702                 Sequence<CShuffleMXdlPerWavePerShuffle,
 
 1703                          CShuffleNXdlPerWavePerShuffle,
 
 1715                 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1718                                        m_thread_data_on_block_idx[
I1],
 
 1719                                        n_thread_data_on_block_idx[
I1],
 
 1720                                        m_thread_data_on_block_idx[
I2],
 
 1721                                        m_thread_data_on_block_idx[
I3],
 
 1722                                        m_thread_data_on_block_idx[
I4],
 
 1723                                        n_thread_data_on_block_idx[
I2]),
 
 1724                       vpgr_to_lds_element_op()};
 
 1730                               CElementwiseOperation,
 
 1732                 CGlobalMemoryDataOperation, 
 
 1734                          CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 1736                          CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, 
 
 1737                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
 1741                 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
 
 1742                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
 
 1745                 CShuffleBlockTransferScalarPerVector_NPerBlock, 
 
 1748                 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 1750                  c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 1752                  lds_to_global_element_op()};
 
 1755             constexpr 
auto sfc_c_vgpr =
 
 1758                                   Sequence<CShuffleMXdlPerWavePerShuffle,
 
 1759                                            CShuffleNXdlPerWavePerShuffle,
 
 1768             constexpr 
auto sfc_c_global =
 
 1772                                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 1774                                            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
 
 1776             constexpr 
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
 
 1778             static_assert(num_access == sfc_c_global.GetNumOfAccess(), 
"wrong!");
 
 1785                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1786                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
 
 1788                                               c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1789                                               c_shuffle_block_buf);
 
 1795                 c_shuffle_block_copy_lds_to_global.Run(
 
 1796                     c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 1797                     c_shuffle_block_buf,
 
 1798                     c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 1801                 if constexpr(access_id < num_access - 1)
 
 1803                     constexpr 
auto c_global_step = sfc_c_global.GetForwardStep(access_id);
 
 1806                     c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
 
 1807                         c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
 
 1813     template <
bool HasMainKBlockLoop,
 
 1816     __device__ 
static void Run(
const ADataType* p_a_grid,
 
 1817                                const BDataType* p_b_grid,
 
 1818                                CDataType* p_c_grid,
 
 1828         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 
 1832         Run<decltype(a_grid_desc_ak0_m_ak1),
 
 1833             decltype(b_grid_desc_bk0_n_bk1),
 
 1834             decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
 
 1836             CGlobalMemoryDataOperation,
 
 1842                      a_grid_desc_ak0_m_ak1,
 
 1843                      b_grid_desc_bk0_n_bk1,
 
 1844                      c_grid_desc_mblock_mperblock_nblock_nperblock);
 
 1847     template <
typename AGridDesc_AK0_M_K1,
 
 1848               typename BGridDesc_BK0_N_K1,
 
 1849               typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
 
 1850               bool HasMainKBlockLoop,
 
 1853     __device__ 
static void Run_2Lds(
const ADataType* p_a_grid,
 
 1854                                     const BDataType* p_b_grid,
 
 1855                                     CDataType* p_c_grid,
 
 1859                                     const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
 
 1860                                     const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
 
 1861                                     const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
 
 1862                                         c_grid_desc_mblock_mperblock_nblock_nperblock)
 
 1864         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1865             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1866         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1867             p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1868         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1869             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 1874         const auto block_work_idx =
 
 1877         if(!block_2_ctile_map.ValidCTileIndex(
 
 1879                make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
 
 1880                           c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
 
 1885         const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
 
 1886         const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
 
 1889         const index_t m_block_data_idx_on_grid =
 
 1890             __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
 
 1892         const index_t n_block_data_idx_on_grid =
 
 1893             __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
 
 1905         auto a_blockwise_copy =
 
 1907                                                 AElementwiseOperation,
 
 1911                                                 ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
 1912                                                 ABlockTransferThreadClusterArrangeOrder,
 
 1915                                                 decltype(a_grid_desc_ak0_m_ak1),
 
 1916                                                 decltype(a_block_desc_ak0_m_ak1),
 
 1917                                                 ABlockTransferSrcAccessOrder,
 
 1919                                                 ABlockTransferSrcVectorDim,
 
 1921                                                 ABlockTransferSrcScalarPerVector,
 
 1922                                                 ABlockTransferDstScalarPerVector_AK1,
 
 1925                                                 AThreadTransferSrcResetCoordinateAfterRun,
 
 1927                                                 BlockwiseGemmPipe::GlobalBufferNum>(
 
 1928                 a_grid_desc_ak0_m_ak1,
 
 1931                 a_block_desc_ak0_m_ak1,
 
 1936         auto b_blockwise_copy =
 
 1938                                                 BElementwiseOperation,
 
 1942                                                 BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
 1943                                                 BBlockTransferThreadClusterArrangeOrder,
 
 1946                                                 decltype(b_grid_desc_bk0_n_bk1),
 
 1947                                                 decltype(b_block_desc_bk0_n_bk1),
 
 1948                                                 BBlockTransferSrcAccessOrder,
 
 1950                                                 BBlockTransferSrcVectorDim,
 
 1952                                                 BBlockTransferSrcScalarPerVector,
 
 1953                                                 BBlockTransferDstScalarPerVector_BK1,
 
 1956                                                 BThreadTransferSrcResetCoordinateAfterRun,
 
 1958                                                 BlockwiseGemmPipe::GlobalBufferNum>(
 
 1959                 b_grid_desc_bk0_n_bk1,
 
 1962                 b_block_desc_bk0_n_bk1,
 
 1968             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
 1970         auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1971             static_cast<ADataType*
>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1973         auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1974             bit_cast<BDataType*>(
static_cast<char*
>(p_shared_0) +
 
 1975                                  a_block_space_size_aligned * 
sizeof(ADataType)),
 
 1976             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1978         auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1979             static_cast<ADataType*
>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1981         auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1982             bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
 
 1983                                  a_block_space_size_aligned * 
sizeof(ADataType)),
 
 1984             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1986         auto a_block_bufs = 
make_tuple(a_block_buf_ping, a_block_buf_pong);
 
 1987         auto b_block_bufs = 
make_tuple(b_block_buf_ping, b_block_buf_pong);
 
 1993         static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
 
 1995         auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
 
 1997         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
 
 1998             (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
 
 2001         blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
 
 2002                                                                          a_block_desc_ak0_m_ak1,
 
 2006                                                                          a_block_slice_copy_step,
 
 2007                                                                          b_grid_desc_bk0_n_bk1,
 
 2008                                                                          b_block_desc_bk0_n_bk1,
 
 2012                                                                          b_block_slice_copy_step,
 
 2014                                                                          num_k_block_main_loop);
 
 2018             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
 
 2019                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
 
 2022             constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
 2023             constexpr 
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
 
 2026             constexpr 
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
 2027                 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 2031             constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
 
 2032                 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 2034             constexpr 
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
 
 2035             constexpr 
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
 
 2036             constexpr 
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
 
 2037             constexpr 
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
 
 2038             constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
 
 2039             constexpr 
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
 
 2040             constexpr 
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
 
 2041             constexpr 
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
 
 2043             constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 2046             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 2047                 static_cast<CShuffleDataType*
>(p_shared_0),
 
 2048                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 2051                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 2071             const auto c_thread_mtx_on_block =
 
 2072                 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
 2074             const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
 2075             const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
 2077             const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
 2083             const auto m_thread_data_on_block_idx =
 
 2084                 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
 2087             const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
 2093             const auto n_thread_data_on_block_idx =
 
 2094                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
 2098             auto c_thread_copy_vgpr_to_lds =
 
 2101                                                    decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 2102                                                    decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 2104                                                    Sequence<CShuffleMXdlPerWavePerShuffle,
 
 2105                                                             CShuffleNXdlPerWavePerShuffle,
 
 2118                     c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2121                                      m_thread_data_on_block_idx[
I1],
 
 2122                                      n_thread_data_on_block_idx[
I1],
 
 2123                                      m_thread_data_on_block_idx[
I2],
 
 2124                                      m_thread_data_on_block_idx[
I3],
 
 2125                                      m_thread_data_on_block_idx[
I4],
 
 2126                                      n_thread_data_on_block_idx[
I2]),
 
 2132                 CElementwiseOperation,      
 
 2133                 CGlobalMemoryDataOperation, 
 
 2135                          CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 2137                          CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, 
 
 2138                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
 2142                 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
 
 2143                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2146                 CShuffleBlockTransferScalarPerVector_NPerBlock, 
 
 2149                 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 2151                  c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 2156             constexpr 
auto sfc_c_vgpr =
 
 2159                                   Sequence<CShuffleMXdlPerWavePerShuffle,
 
 2160                                            CShuffleNXdlPerWavePerShuffle,
 
 2169             constexpr 
auto sfc_c_global =
 
 2173                                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 2175                                            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
 
 2177             constexpr 
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
 
 2179             static_assert(num_access == sfc_c_global.GetNumOfAccess(), 
"wrong!");
 
 2186                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2187                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
 
 2189                                               c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2190                                               c_shuffle_block_buf);
 
 2196                 c_shuffle_block_copy_lds_to_global.Run(
 
 2197                     c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 2198                     c_shuffle_block_buf,
 
 2199                     c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 2202                 if constexpr(access_id < num_access - 1)
 
 2204                     constexpr 
auto c_global_step = sfc_c_global.GetForwardStep(access_id);
 
 2207                     c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
 
 2208                         c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
 
 2214     template <
bool HasMainKBlockLoop,
 
 2217     __device__ 
static void Run_2Lds(
const ADataType* p_a_grid,
 
 2218                                     const BDataType* p_b_grid,
 
 2219                                     CDataType* p_c_grid,
 
 2231         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 
 2235         Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
 
 2236                  decltype(b_grid_desc_bk0_n_bk1),
 
 2237                  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2239                  CGlobalMemoryDataOperation,
 
 2246                           a_grid_desc_ak0_m_ak1,
 
 2247                           b_grid_desc_bk0_n_bk1,
 
 2248                           c_grid_desc_mblock_mperblock_nblock_nperblock);
 
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:30
 
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
 
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
 
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
GemmSpecialization
Definition: gemm_specialization.hpp:11
 
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:47
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
InMemoryDataOperationEnum
Definition: ck.hpp:277
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
 
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
_Float16 half_t
Definition: data_type.hpp:31
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
 
ushort bhalf_t
Definition: data_type.hpp:30
 
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
 
constexpr auto BlockGemmPipeline_Selector()
Definition: blockwise_gemm_pipeline_wmma_selector.hpp:32
 
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
 
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
 
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:185
 
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
constexpr bool is_same_v
Definition: type.hpp:283
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:299
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
 
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1697
 
unsigned int uint32_t
Definition: stdint.h:126
 
signed int int32_t
Definition: stdint.h:123
 
Definition: block_to_ctile_map.hpp:271
 
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:283
 
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:716
 
const BElementwiseOperation b_element_op
Definition: gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:643
 
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_, bool is_reduce_=false, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CElementwiseOperation c_element_op=CElementwiseOperation{})
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:717
 
const BDataType * p_b_grid
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:759
 
CDataType * p_c_grid
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:760
 
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:748
 
const AElementwiseOperation a_element_op
Definition: gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:642
 
__host__ __device__ bool IsAtomicAdd() const
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:753
 
const ADataType * p_a_grid
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:758
 
const CElementwiseOperation c_element_op
Definition: gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:644
 
bool is_reduce
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:761
 
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:641
 
index_t N
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:695
 
index_t NPadded
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:702
 
index_t KBatch
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:700
 
index_t StrideA
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:697
 
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:642
 
CElementwiseOperation c_element_op_
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:711
 
index_t BK0
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:706
 
index_t M
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:694
 
index_t NBlock
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:708
 
index_t MPadded
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:701
 
index_t K
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:696
 
index_t StrideB
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:698
 
index_t KPadded
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:704
 
index_t StrideC
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:699
 
index_t MBlock
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:707
 
BElementwiseOperation b_element_op_
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:710
 
index_t AK0
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:705
 
index_t KRead
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:703
 
AElementwiseOperation a_element_op_
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:709
 
__host__ void Print() const
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:673
 
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:765
 
index_t a_k_split_offset
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:814
 
index_t b_k_split_offset
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:815
 
__device__ SplitKBatchOffset(Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:767
 
index_t c_reduce_offset
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:816
 
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:247
 
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:572
 
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:451
 
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:337
 
static constexpr auto is_scale_mfma
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:273
 
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:314
 
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:304
 
static constexpr auto BK1Number
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:261
 
static constexpr index_t APackedSize
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:285
 
static constexpr bool is_single_rate_mfma
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:264
 
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1405
 
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:283
 
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1413
 
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:319
 
static constexpr auto I2
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:250
 
static constexpr index_t KPack
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:274
 
static constexpr auto lcm_AK1_BK1
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:263
 
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:355
 
static constexpr auto I7
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:255
 
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1136
 
static constexpr auto I5
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:253
 
static __device__ constexpr bool IsValidCompilationParameter()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1166
 
static constexpr auto AK1Number
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:260
 
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1202
 
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1853
 
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:299
 
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:344
 
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:2217
 
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:309
 
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:563
 
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:325
 
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1138
 
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:960
 
static constexpr index_t BPackedSize
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:292
 
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1437
 
static constexpr auto I6
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:254
 
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:819
 
static constexpr auto I1
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:249
 
static constexpr auto I0
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:248
 
static constexpr auto I3
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:251
 
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1816
 
static constexpr auto I4
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:252
 
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:331
 
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:369
 
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:349
 
static constexpr auto BK0Number
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:259
 
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1097
 
static constexpr auto AK0Number
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:258
 
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1398
 
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:580
 
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition: xdlops_gemm.hpp:1208
 
Definition: sequence.hpp:43
 
Definition: tensor_space_filling_curve.hpp:20
 
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
 
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
 
Definition: threadwise_tensor_slice_transfer.hpp:39
 
Definition: integral_constant.hpp:20
 
Definition: data_type.hpp:187
 
Definition: functional2.hpp:33
 
Definition: device_base.hpp:197
 
Definition: unary_element_wise_operation.hpp:340
 
#define CK_ENV(name)
Definition: env.hpp:129