20           typename ComputeDataType,
 
   24           typename AMmaTileDesc,
 
   25           typename BMmaTileDesc,
 
   26           index_t ABlockTransferSrcScalarPerVector,
 
   27           index_t BBlockTransferSrcScalarPerVector,
 
   36 struct BlockwiseGemmXdlops_pipeline_v4
 
   43           typename ComputeDataType,
 
   47           typename AMmaTileDesc,
 
   48           typename BMmaTileDesc,
 
   49           index_t ABlockTransferSrcScalarPerVector,
 
   50           index_t BBlockTransferSrcScalarPerVector,
 
   71                                        ABlockTransferSrcScalarPerVector,
 
   72                                        BBlockTransferSrcScalarPerVector,
 
   90                                         ABlockTransferSrcScalarPerVector,
 
   91                                         BBlockTransferSrcScalarPerVector,
 
  111                                                    ABlockTransferSrcScalarPerVector,
 
  112                                                    BBlockTransferSrcScalarPerVector,
 
  124     using Base::xdlops_gemm;
 
  127     using Base::CalculateCThreadOriginDataIndex;
 
  128     using Base::CalculateCThreadOriginDataIndex8D;
 
  129     using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  130     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  131     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  132     using Base::GetCThreadBuffer;
 
  133     using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  134     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  135     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  136     using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  137     using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  139     using Base::a_block_desc_m0_m1_m2_k;
 
  140     using Base::b_block_desc_n0_n1_n2_k;
 
  142     using Base::AMmaKStride;
 
  143     using Base::BMmaKStride;
 
  154         return num_loop > PrefetchStages;
 
  159         if(num_loop % HotloopUnroll == 1)
 
  173         constexpr 
auto num_ds_read_inst_a =
 
  177         constexpr 
auto num_ds_read_inst_b =
 
  183         constexpr 
auto num_dswrite_per_issue_a =
 
  185         constexpr 
auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;
 
  188         constexpr 
auto num_dswrite_per_issue_b =
 
  190         constexpr 
auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;
 
  192         constexpr 
auto num_mfma_per_issue =
 
  199                 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); 
 
  200                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  205                 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  206                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  209             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
 
  210             __builtin_amdgcn_sched_group_barrier(0x008,
 
  211                                                  num_mfma_per_issue - num_dsread_per_issue_a -
 
  212                                                      num_dswrite_per_issue_a,
 
  220                 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); 
 
  221                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  226                 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  227                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  230             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
 
  231             __builtin_amdgcn_sched_group_barrier(0x008,
 
  232                                                  num_mfma_per_issue - num_dsread_per_issue_a -
 
  233                                                      num_dswrite_per_issue_b,
 
  236         __builtin_amdgcn_sched_barrier(0);
 
  239     template <
bool HasMainLoop,
 
  243               typename ABlockTransfer,
 
  244               typename AGridBuffer,
 
  245               typename ABlockBuffer,
 
  246               typename ABlockTransferStep,
 
  249               typename BBlockTransfer,
 
  250               typename BGridBuffer,
 
  251               typename BBlockBuffer,
 
  252               typename BBlockTransferStep,
 
  253               typename CThreadBuffer>
 
  254     __device__ 
void Run(
const AGridDesc& a_grid_desc,
 
  255                         const ABlockDesc& a_block_desc,
 
  256                         ABlockTransfer& a_blockwise_copy,
 
  257                         const AGridBuffer& a_grid_buf,
 
  258                         ABlockBuffer& a_block_buf,
 
  259                         const ABlockTransferStep& a_block_copy_step,
 
  260                         const BGridDesc& b_grid_desc,
 
  261                         const BBlockDesc& b_block_desc,
 
  262                         BBlockTransfer& b_blockwise_copy,
 
  263                         const BGridBuffer& b_grid_buf,
 
  264                         BBlockBuffer& b_block_buf,
 
  265                         const BBlockTransferStep& b_block_copy_step,
 
  266                         CThreadBuffer& c_thread_buf,
 
  269         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
 
  271         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
 
  278         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  279         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  281         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  282         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  285         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I0));
 
  286         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I0));
 
  310         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  311         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  313         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  314         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  317         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I1));
 
  318         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I1));
 
  321         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  322         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  324         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  325         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  328         c_thread_buf.Clear();
 
  331         if constexpr(HasMainLoop)
 
  337                 auto LoopFunc = [&](
auto lds_read_buf,
 
  338                                     auto lds_read_reg_buf,
 
  347                                                a_block_buf.At(lds_read_buf),
 
  350                                                a_thread_bufs(lds_read_reg_buf));
 
  355                                                b_block_buf.At(lds_read_buf),
 
  358                                                b_thread_bufs(lds_read_reg_buf));
 
  362                     a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
 
  363                     b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
 
  365                     a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  366                     b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  368                     a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  369                     b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  378                                     a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  379                                         a_thread_bufs[mfma_reg_buf]
 
  382                                     b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  383                                         b_thread_bufs[mfma_reg_buf]
 
  388                                 using mfma_input_type =
 
  396                                     a_thread_vec.template AsType<mfma_input_type>(),
 
  397                                     b_thread_vec.template AsType<mfma_input_type>(),
 
  410             } 
while(i < (num_loop - PrefetchStages));
 
  413         auto ReadWriteCompFunc = [&](
auto lds_read_buf,
 
  414                                      auto lds_read_reg_buf,
 
  423                                        a_block_buf.At(lds_read_buf),
 
  426                                        a_thread_bufs(lds_read_reg_buf));
 
  431                                        b_block_buf.At(lds_read_buf),
 
  434                                        b_thread_bufs(lds_read_reg_buf));
 
  438             a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
 
  439             b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
 
  448                             a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  451                             b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  456                         using mfma_input_type =
 
  462                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  463                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  472         auto ReadCompFunc = [&](
auto lds_read_buf, 
auto lds_read_reg_buf, 
auto mfma_reg_buf) {
 
  479                                        a_block_buf.At(lds_read_buf),
 
  482                                        a_thread_bufs(lds_read_reg_buf));
 
  487                                        b_block_buf.At(lds_read_buf),
 
  490                                        b_thread_bufs(lds_read_reg_buf));
 
  501                             a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  504                             b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  509                         using mfma_input_type =
 
  515                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  516                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  525         auto CompFunc = [&](
auto mfma_reg_buf) {
 
  533                             a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  536                             b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  541                         using mfma_input_type =
 
  547                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  548                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  569     using Base::a_thread_copy_;
 
  570     using Base::a_thread_desc_;
 
  571     using Base::b_thread_copy_;
 
  572     using Base::b_thread_desc_;
 
  573     using Base::c_thread_desc_;
 
  587           typename ComputeDataType,
 
  588           typename AccDataType,
 
  591           typename AMmaTileDesc,
 
  592           typename BMmaTileDesc,
 
  593           index_t ABlockTransferSrcScalarPerVector,
 
  594           index_t BBlockTransferSrcScalarPerVector,
 
  610           typename ComputeDataType,
 
  611           typename AccDataType,
 
  614           typename AMmaTileDesc,
 
  615           typename BMmaTileDesc,
 
  616           index_t ABlockTransferSrcScalarPerVector,
 
  617           index_t BBlockTransferSrcScalarPerVector,
 
  638                                                  ABlockTransferSrcScalarPerVector,
 
  639                                                  BBlockTransferSrcScalarPerVector,
 
  657                                         ABlockTransferSrcScalarPerVector,
 
  658                                         BBlockTransferSrcScalarPerVector,
 
  678                                                    ABlockTransferSrcScalarPerVector,
 
  679                                                    BBlockTransferSrcScalarPerVector,
 
  691     using Base::xdlops_gemm;
 
  694     using Base::CalculateCThreadOriginDataIndex;
 
  695     using Base::CalculateCThreadOriginDataIndex8D;
 
  696     using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  697     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  698     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  699     using Base::GetCThreadBuffer;
 
  700     using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  701     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  702     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  703     using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  704     using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  706     using Base::a_block_desc_m0_m1_m2_k;
 
  707     using Base::b_block_desc_n0_n1_n2_k;
 
  709     using Base::AMmaKStride;
 
  710     using Base::BMmaKStride;
 
  721         return num_loop > PrefetchStages;
 
  726         if(num_loop % HotloopUnroll == 1)
 
  740         constexpr 
auto num_ds_read_inst_a =
 
  741             HotLoopInstList::A_LDS_Read_Width * 
sizeof(ADataType) == 16
 
  742                 ? HotLoopInstList::A_LDS_Read_Inst_Num
 
  743                 : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
 
  744         constexpr 
auto num_ds_read_inst_b =
 
  745             HotLoopInstList::B_LDS_Read_Width * 
sizeof(BDataType) == 16
 
  746                 ? HotLoopInstList::B_LDS_Read_Inst_Num
 
  747                 : HotLoopInstList::B_LDS_Read_Inst_Num / 2;
 
  749         constexpr 
auto num_issue_a             = HotLoopInstList::A_Buffer_Load_Inst_Num;
 
  750         constexpr 
auto num_dswrite_per_issue_a = 0;
 
  751         constexpr 
auto num_dsread_per_issue_a  = num_ds_read_inst_a / num_issue_a;
 
  753         constexpr 
auto num_issue_b             = HotLoopInstList::B_Buffer_Load_Inst_Num;
 
  754         constexpr 
auto num_dswrite_per_issue_b = 0;
 
  755         constexpr 
auto num_dsread_per_issue_b  = num_ds_read_inst_b / num_issue_b;
 
  757         constexpr 
auto num_mfma_per_issue =
 
  758             HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b);
 
  764                 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); 
 
  765                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  770                 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  771                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  774             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
 
  775             __builtin_amdgcn_sched_group_barrier(0x008,
 
  776                                                  num_mfma_per_issue - num_dsread_per_issue_a -
 
  777                                                      num_dswrite_per_issue_a,
 
  785                 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); 
 
  786                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  791                 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  792                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  795             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
 
  796             __builtin_amdgcn_sched_group_barrier(0x008,
 
  797                                                  num_mfma_per_issue - num_dsread_per_issue_a -
 
  798                                                      num_dswrite_per_issue_b,
 
  801         __builtin_amdgcn_sched_barrier(0);
 
  804     template <
bool HasMainLoop,
 
  808               typename ABlockTransfer,
 
  809               typename AGridBuffer,
 
  810               typename ABlockBuffer,
 
  811               typename ABlockTransferStep,
 
  814               typename BBlockTransfer,
 
  815               typename BGridBuffer,
 
  816               typename BBlockBuffer,
 
  817               typename BBlockTransferStep,
 
  818               typename CThreadBuffer>
 
  819     __device__ 
void Run(
const AGridDesc& a_grid_desc,
 
  820                         const ABlockDesc& a_block_desc,
 
  821                         ABlockTransfer& a_blockwise_copy,
 
  822                         const AGridBuffer& a_grid_buf,
 
  823                         ABlockBuffer& a_block_buf,
 
  824                         const ABlockTransferStep& a_block_copy_step,
 
  825                         const BGridDesc& b_grid_desc,
 
  826                         const BBlockDesc& b_block_desc,
 
  827                         BBlockTransfer& b_blockwise_copy,
 
  828                         const BGridBuffer& b_grid_buf,
 
  829                         BBlockBuffer& b_block_buf,
 
  830                         const BBlockTransferStep& b_block_copy_step,
 
  831                         CThreadBuffer& c_thread_buf,
 
  834         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
 
  835             a_thread_desc_.GetElementSpaceSize());
 
  836         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
 
  837             b_thread_desc_.GetElementSpaceSize());
 
  843         a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(I0));
 
  844         b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(I0));
 
  846         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  847         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  854                 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  862                 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  872         a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(I1));
 
  873         b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(I1));
 
  875         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  876         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  879         c_thread_buf.Clear();
 
  882         if constexpr(HasMainLoop)
 
  888                 auto LoopFunc = [&](
auto lds_read_buf,
 
  889                                     auto lds_read_reg_buf,
 
  896                             a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  898                                                a_block_buf.At(lds_read_buf),
 
  901                                                a_thread_bufs(lds_read_reg_buf));
 
  904                             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  906                                                b_block_buf.At(lds_read_buf),
 
  909                                                b_thread_bufs(lds_read_reg_buf));
 
  913                     a_blockwise_copy.Run(
 
  914                         a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf));
 
  915                     b_blockwise_copy.Run(
 
  916                         b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf));
 
  918                     a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  919                     b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  928                                     a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  929                                         a_thread_bufs[mfma_reg_buf]
 
  930                                                      [
Number<a_thread_desc_.CalculateOffset(
 
  932                                     b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  933                                         b_thread_bufs[mfma_reg_buf]
 
  934                                                      [
Number<b_thread_desc_.CalculateOffset(
 
  938                                 using mfma_input_type =
 
  940                                                          xdlops_gemm.K1PerXdlops>::type;
 
  943                                     c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  946                                     a_thread_vec.template AsType<mfma_input_type>(),
 
  947                                     b_thread_vec.template AsType<mfma_input_type>(),
 
  956                 LoopFunc(I1, I1, I0, I0);
 
  957                 LoopFunc(I0, I0, I1, I1);
 
  960             } 
while(i < (num_loop - PrefetchStages));
 
  963         auto ReadWriteCompFunc = [&](
auto lds_read_buf,
 
  964                                      auto lds_read_reg_buf,
 
  971                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  973                                        a_block_buf.At(lds_read_buf),
 
  976                                        a_thread_bufs(lds_read_reg_buf));
 
  979                     b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  981                                        b_block_buf.At(lds_read_buf),
 
  984                                        b_thread_bufs(lds_read_reg_buf));
 
  988             a_blockwise_copy.Run(
 
  989                 a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf));
 
  990             b_blockwise_copy.Run(
 
  991                 b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf));
 
 1000                             a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
 1001                                 a_thread_bufs[mfma_reg_buf][
Number<a_thread_desc_.CalculateOffset(
 
 1003                             b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
 1004                                 b_thread_bufs[mfma_reg_buf][
Number<b_thread_desc_.CalculateOffset(
 
 1008                         using mfma_input_type =
 
 1012                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
 1014                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
 1015                                         b_thread_vec.template AsType<mfma_input_type>(),
 
 1024         auto ReadCompFunc = [&](
auto lds_read_buf, 
auto lds_read_reg_buf, 
auto mfma_reg_buf) {
 
 1029                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
 1031                                        a_block_buf.At(lds_read_buf),
 
 1034                                        a_thread_bufs(lds_read_reg_buf));
 
 1037                     b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
 1039                                        b_block_buf.At(lds_read_buf),
 
 1042                                        b_thread_bufs(lds_read_reg_buf));
 
 1053                             a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
 1054                                 a_thread_bufs[mfma_reg_buf][
Number<a_thread_desc_.CalculateOffset(
 
 1056                             b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
 1057                                 b_thread_bufs[mfma_reg_buf][
Number<b_thread_desc_.CalculateOffset(
 
 1061                         using mfma_input_type =
 
 1065                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
 1067                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
 1068                                         b_thread_vec.template AsType<mfma_input_type>(),
 
 1077         auto CompFunc = [&](
auto mfma_reg_buf) {
 
 1085                             a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
 1086                                 a_thread_bufs[mfma_reg_buf][
Number<a_thread_desc_.CalculateOffset(
 
 1088                             b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
 1089                                 b_thread_bufs[mfma_reg_buf][
Number<b_thread_desc_.CalculateOffset(
 
 1093                         using mfma_input_type =
 
 1097                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
 1099                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
 1100                                         b_thread_vec.template AsType<mfma_input_type>(),
 
 1109             ReadWriteCompFunc(I1, I1, I0, I0);
 
 1110             ReadCompFunc(I0, I0, I1);
 
 1115             ReadCompFunc(I1, I1, I0);
 
 1121     using Base::a_thread_copy_;
 
 1122     using Base::a_thread_desc_;
 
 1123     using Base::b_thread_copy_;
 
 1124     using Base::b_thread_desc_;
 
 1125     using Base::c_thread_desc_;
 
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
__device__ void block_sync_lds_direct_load()
Definition: synchronization.hpp:43
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:299
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
 
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
 
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:58
 
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
 
static constexpr index_t B_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:46
 
static constexpr index_t A_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:49
 
static constexpr index_t A_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:82
 
static constexpr index_t B_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:51
 
static constexpr index_t A_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:44
 
static constexpr index_t C_MFMA_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:54
 
static constexpr index_t A_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:39
 
static constexpr index_t B_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:41
 
static constexpr index_t B_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:83
 
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:152
 
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::HotLoopScheduler static constexpr __device__ void HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:169
 
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:157
 
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:254
 
Definition: blockwise_gemm_pipeline_xdlops.hpp:103
 
static constexpr auto I1
Definition: blockwise_gemm_pipeline_xdlops.hpp:105
 
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:961
 
static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops.hpp:373
 
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:967
 
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:991
 
static constexpr auto I0
Definition: blockwise_gemm_pipeline_xdlops.hpp:104
 
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:453
 
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:990
 
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:454
 
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:955
 
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_pipeline_xdlops.hpp:120
 
ck::BlockwiseGemmXdlopsDirectLoad_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:724
 
ck::BlockwiseGemmXdlopsDirectLoad_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::HotLoopScheduler static constexpr __device__ void HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:736
 
ck::BlockwiseGemmXdlopsDirectLoad_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:719
 
ck::BlockwiseGemmXdlopsDirectLoad_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:819
 
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:604
 
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1293
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: dtype_vector.hpp:10