20           typename ComputeDataType,
 
   24           typename AMmaTileDesc,
 
   25           typename BMmaTileDesc,
 
   26           index_t ABlockTransferSrcScalarPerVector,
 
   27           index_t BBlockTransferSrcScalarPerVector,
 
   43           typename ComputeDataType,
 
   47           typename AMmaTileDesc,
 
   48           typename BMmaTileDesc,
 
   49           index_t ABlockTransferSrcScalarPerVector,
 
   50           index_t BBlockTransferSrcScalarPerVector,
 
   71                                        ABlockTransferSrcScalarPerVector,
 
   72                                        BBlockTransferSrcScalarPerVector,
 
   90                                         ABlockTransferSrcScalarPerVector,
 
   91                                         BBlockTransferSrcScalarPerVector,
 
  111                                                    ABlockTransferSrcScalarPerVector,
 
  112                                                    BBlockTransferSrcScalarPerVector,
 
  124     using Base::xdlops_gemm;
 
  127     using Base::CalculateCThreadOriginDataIndex;
 
  128     using Base::CalculateCThreadOriginDataIndex8D;
 
  129     using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  130     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  131     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  132     using Base::GetCThreadBuffer;
 
  133     using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  134     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  135     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  136     using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  137     using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  139     using Base::a_block_desc_m0_m1_m2_k;
 
  140     using Base::b_block_desc_n0_n1_n2_k;
 
  142     using Base::AMmaKStride;
 
  143     using Base::BMmaKStride;
 
  153         return num_loop > PrefetchStages;
 
  164 #if !defined(__gfx11__) && !defined(__gfx12__) 
  167         constexpr 
auto num_ds_read_inst_a =
 
  168             HotLoopInstList::A_LDS_Read_Width * 
sizeof(ADataType) == 16
 
  169                 ? HotLoopInstList::A_LDS_Read_Inst_Num
 
  170                 : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
 
  171         constexpr 
auto num_ds_read_inst_b =
 
  172             HotLoopInstList::B_LDS_Read_Width * 
sizeof(BDataType) == 16
 
  173                 ? HotLoopInstList::B_LDS_Read_Inst_Num
 
  174                 : HotLoopInstList::B_LDS_Read_Inst_Num / 2;
 
  176         constexpr 
auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
 
  177         constexpr 
auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
 
  179         constexpr 
auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
 
  180         constexpr 
auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
 
  182         constexpr 
auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
 
  183         constexpr 
auto mfma_cycle    = HotLoopInstList::C_MFMA_Inst_Cycle;
 
  185         constexpr 
auto ds_read_a_issue_cycle =
 
  186             HotLoopInstList::A_LDS_Read_Width * 
sizeof(ADataType) == 16 ? 8 : 4;
 
  187         constexpr 
auto ds_read_b_issue_cycle =
 
  188             HotLoopInstList::B_LDS_Read_Width * 
sizeof(BDataType) == 16 ? 8 : 4;
 
  189         constexpr 
auto ds_read_a_mfma_rate =
 
  190             (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
 
  191         constexpr 
auto ds_read_b_mfma_rate =
 
  192             (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
 
  194         constexpr 
auto num_dsread_a_mfma =
 
  195             (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
 
  196         constexpr 
auto num_dsread_b_mfma =
 
  197             (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
 
  205         constexpr 
auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
 
  206         constexpr 
auto num_mfma_per_issue =
 
  207             num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
 
  208         constexpr 
auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
 
  209         constexpr 
auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
 
  215                 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  216                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  218             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
 
  219             __builtin_amdgcn_sched_group_barrier(
 
  220                 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); 
 
  226                 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  227                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  229             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
 
  230             __builtin_amdgcn_sched_group_barrier(
 
  231                 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); 
 
  236             if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
 
  239                 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); 
 
  243                 __builtin_amdgcn_sched_group_barrier(0x100,
 
  244                                                      num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
 
  248             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  252             if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
 
  255                 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); 
 
  259                 __builtin_amdgcn_sched_group_barrier(0x100,
 
  260                                                      num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
 
  264             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  269     template <
bool HasMainLoop,
 
  273               typename ABlockTransfer,
 
  274               typename AGridBuffer,
 
  275               typename ABlockBuffer,
 
  276               typename ABlockTransferStep,
 
  279               typename BBlockTransfer,
 
  280               typename BGridBuffer,
 
  281               typename BBlockBuffer,
 
  282               typename BBlockTransferStep,
 
  283               typename CThreadBuffer>
 
  284     __device__ 
void Run(
const AGridDesc& a_grid_desc,
 
  285                         const ABlockDesc& a_block_desc,
 
  286                         ABlockTransfer& a_blockwise_copy,
 
  287                         const AGridBuffer& a_grid_buf,
 
  288                         ABlockBuffer& a_block_buf,
 
  289                         const ABlockTransferStep& a_block_copy_step,
 
  290                         const BGridDesc& b_grid_desc,
 
  291                         const BBlockDesc& b_block_desc,
 
  292                         BBlockTransfer& b_blockwise_copy,
 
  293                         const BGridBuffer& b_grid_buf,
 
  294                         BBlockBuffer& b_block_buf,
 
  295                         const BBlockTransferStep& b_block_copy_step,
 
  296                         CThreadBuffer& c_thread_buf,
 
  299         __builtin_amdgcn_sched_barrier(0);
 
  300         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
 
  301             a_thread_desc_.GetElementSpaceSize());
 
  302         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
 
  303             b_thread_desc_.GetElementSpaceSize());
 
  306         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  307         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  309         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  310         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  313         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 
  314         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
 
  317         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  318         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  320         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  321         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  324         c_thread_buf.Clear();
 
  330                 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  338                 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  347         __builtin_amdgcn_sched_barrier(0);
 
  350         if constexpr(HasMainLoop)
 
  357                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 
  358                 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
 
  360                 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  361                 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  363                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  364                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  373                                 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  374                                     a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  376                                 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  377                                     b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  381                             using mfma_input_type =
 
  383                                                      xdlops_gemm.K1PerXdlops>::type;
 
  386                                 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  389                                 a_thread_vec.template AsType<mfma_input_type>(),
 
  390                                 b_thread_vec.template AsType<mfma_input_type>(),
 
  400                         a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  408                         b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  418                 __builtin_amdgcn_sched_barrier(0);
 
  421             } 
while(i < (num_loop - 1));
 
  433                             a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  434                                 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  436                             b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  437                                 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  441                         using mfma_input_type =
 
  445                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  447                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  448                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  460     using Base::a_thread_copy_;
 
  461     using Base::a_thread_desc_;
 
  462     using Base::b_thread_copy_;
 
  463     using Base::b_thread_desc_;
 
  464     using Base::c_thread_desc_;
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:299
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
 
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
 
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:58
 
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
 
ck::BlockwiseGemmXdlops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v3.hpp:156
 
ck::BlockwiseGemmXdlops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v3.hpp:284
 
ck::BlockwiseGemmXdlops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::HotLoopScheduler static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops_v3.hpp:162
 
ck::BlockwiseGemmXdlops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v3.hpp:151
 
Definition: blockwise_gemm_pipeline_xdlops_v3.hpp:37
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: dtype_vector.hpp:10