20           typename ComputeDataType,
 
   24           typename AMmaTileDesc,
 
   25           typename BMmaTileDesc,
 
   26           index_t ABlockTransferSrcScalarPerVector,
 
   27           index_t BBlockTransferSrcScalarPerVector,
 
   43           typename ComputeDataType,
 
   47           typename AMmaTileDesc,
 
   48           typename BMmaTileDesc,
 
   49           index_t ABlockTransferSrcScalarPerVector,
 
   50           index_t BBlockTransferSrcScalarPerVector,
 
   71                                        ABlockTransferSrcScalarPerVector,
 
   72                                        BBlockTransferSrcScalarPerVector,
 
   90                                         ABlockTransferSrcScalarPerVector,
 
   91                                         BBlockTransferSrcScalarPerVector,
 
  111                                                    ABlockTransferSrcScalarPerVector,
 
  112                                                    BBlockTransferSrcScalarPerVector,
 
  126     using Base::xdlops_gemm;
 
  129     using Base::CalculateCThreadOriginDataIndex;
 
  130     using Base::CalculateCThreadOriginDataIndex8D;
 
  131     using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  132     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  133     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  134     using Base::GetCThreadBuffer;
 
  135     using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  136     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  137     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  138     using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  139     using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  141     using Base::a_block_desc_m0_m1_m2_k;
 
  142     using Base::b_block_desc_n0_n1_n2_k;
 
  144     using Base::AMmaKStride;
 
  145     using Base::BMmaKStride;
 
  156         return num_loop > PrefetchStages;
 
  161         if(num_loop % HotloopUnroll == 1)
 
  176         constexpr 
auto num_ds_read_inst_a =
 
  177             HotLoopInstList::A_LDS_Read_Width * 
sizeof(ADataType) == 16
 
  178                 ? HotLoopInstList::A_LDS_Read_Inst_Num
 
  179                 : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
 
  180         constexpr 
auto num_ds_read_inst_b =
 
  181             HotLoopInstList::B_LDS_Read_Width * 
sizeof(BDataType) == 16
 
  182                 ? HotLoopInstList::B_LDS_Read_Inst_Num
 
  183                 : HotLoopInstList::B_LDS_Read_Inst_Num / 2;
 
  185         constexpr 
auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
 
  186         constexpr 
auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
 
  188         constexpr 
auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
 
  189         constexpr 
auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
 
  191         constexpr 
auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
 
  193         constexpr 
auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
 
  194         constexpr 
auto ds_read_a_issue_cycle =
 
  195             HotLoopInstList::A_LDS_Read_Width * 
sizeof(ADataType) == 16 ? 8 : 4;
 
  196         constexpr 
auto ds_read_b_issue_cycle =
 
  197             HotLoopInstList::B_LDS_Read_Width * 
sizeof(BDataType) == 16 ? 8 : 4;
 
  198         constexpr 
auto ds_read_a_mfma_rate =
 
  199             (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
 
  200         constexpr 
auto ds_read_b_mfma_rate =
 
  201             (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
 
  203         constexpr 
auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
 
  204         constexpr 
auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
 
  205         constexpr 
auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat;
 
  206         constexpr 
auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat;
 
  208         constexpr 
auto num_dsread_stage1_a_mfma =
 
  209             (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
 
  210         constexpr 
auto num_dsread_stage1_b_mfma =
 
  211             (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
 
  212         constexpr 
auto num_dsread_stage3_a_mfma =
 
  213             (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
 
  214         constexpr 
auto num_dsread_stage3_b_mfma =
 
  215             (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
 
  217         constexpr 
auto num_mfma_stage2 = num_mfma_inst - num_ds_read_inst_a / ds_read_a_mfma_rate -
 
  218                                          num_ds_read_inst_b / ds_read_b_mfma_rate;
 
  219         constexpr 
auto num_mfma_per_issue =
 
  220             num_mfma_stage2 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
 
  221         constexpr 
auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
 
  222         constexpr 
auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
 
  227             if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >=
 
  230                 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); 
 
  234                 __builtin_amdgcn_sched_group_barrier(
 
  236                     num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
 
  239             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  243             if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >=
 
  246                 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); 
 
  250                 __builtin_amdgcn_sched_group_barrier(
 
  252                     num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
 
  255             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  263                 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  264                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  266             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
 
  267             __builtin_amdgcn_sched_group_barrier(
 
  268                 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); 
 
  274                 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  275                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  277             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
 
  278             __builtin_amdgcn_sched_group_barrier(
 
  279                 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); 
 
  285             if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >=
 
  288                 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); 
 
  292                 __builtin_amdgcn_sched_group_barrier(
 
  294                     num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
 
  297             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  301             if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >=
 
  304                 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); 
 
  308                 __builtin_amdgcn_sched_group_barrier(
 
  310                     num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
 
  313             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  318         __builtin_amdgcn_sched_barrier(0);
 
  321     template <
bool HasMainLoop,
 
  325               typename ABlockTransfer,
 
  326               typename AGridBuffer,
 
  327               typename ABlockBuffer,
 
  328               typename ABlockTransferStep,
 
  331               typename BBlockTransfer,
 
  332               typename BGridBuffer,
 
  333               typename BBlockBuffer,
 
  334               typename BBlockTransferStep,
 
  335               typename CThreadBuffer>
 
  336     __device__ 
void Run(
const AGridDesc& a_grid_desc,
 
  337                         const ABlockDesc& a_block_desc,
 
  338                         ABlockTransfer& a_blockwise_copy,
 
  339                         const AGridBuffer& a_grid_buf,
 
  340                         ABlockBuffer& a_block_buf,
 
  341                         const ABlockTransferStep& a_block_copy_step,
 
  342                         const BGridDesc& b_grid_desc,
 
  343                         const BBlockDesc& b_block_desc,
 
  344                         BBlockTransfer& b_blockwise_copy,
 
  345                         const BGridBuffer& b_grid_buf,
 
  346                         BBlockBuffer& b_block_buf,
 
  347                         const BBlockTransferStep& b_block_copy_step,
 
  348                         CThreadBuffer& c_thread_buf,
 
  351         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
 
  352             a_thread_desc_.GetElementSpaceSize());
 
  353         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
 
  354             b_thread_desc_.GetElementSpaceSize());
 
  357         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
 
  358         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
 
  360         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  361         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  364         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
 
  365         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
 
  368         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
 
  369         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
 
  371         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  372         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  375         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
 
  376         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
 
  378         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  379         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  382         c_thread_buf.Clear();
 
  387             a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  395             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  404         if constexpr(HasMainLoop)
 
  409                 auto LoopFunc = [&](
auto vmem_buf) {
 
  414                         if constexpr(k0 == (KRepeat - 1))
 
  418                             a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
 
  419                             b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);
 
  421                             a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf);
 
  422                             b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf);
 
  424                             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  425                             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  432                                     a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  433                                         a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  437                                     b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  438                                         b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  442                                 using mfma_input_type =
 
  444                                                          xdlops_gemm.K1PerXdlops>::type;
 
  447                                     c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  450                                     a_thread_vec.template AsType<mfma_input_type>(),
 
  451                                     b_thread_vec.template AsType<mfma_input_type>(),
 
  456                                 a_block_desc_m0_m1_m2_k,
 
  466                                 b_block_desc_n0_n1_n2_k,
 
  482             } 
while(i < (num_loop - PrefetchStages));
 
  485         auto ReadWriteCompFunc = [&](
auto vmem_buf) {
 
  490                 if constexpr(k0 == (KRepeat - 1))
 
  494                     a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
 
  495                     b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);
 
  502                             a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  503                                 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  507                             b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  508                                 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  512                         using mfma_input_type =
 
  516                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  518                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  519                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  523                         a_block_desc_m0_m1_m2_k,
 
  533                         b_block_desc_n0_n1_n2_k,
 
  544         auto ReadCompFunc = [&]() {
 
  548             static_for<0, KRepeat - 1, 1>{}([&](
auto k0) {
 
  552                             a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  553                                 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  557                             b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  558                                 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  562                         using mfma_input_type =
 
  566                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  568                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  569                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  574                         a_block_desc_m0_m1_m2_k,
 
  584                         b_block_desc_n0_n1_n2_k,
 
  596                         a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
 
  600                         b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
 
  604                     using mfma_input_type =
 
  608                         c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  610                     xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  611                                     b_thread_vec.template AsType<mfma_input_type>(),
 
  621             ReadWriteCompFunc(I0);
 
  622             ReadWriteCompFunc(I1);
 
  627             ReadWriteCompFunc(I0);
 
  634     static constexpr 
auto a_thread_desc_ =
 
  638     static constexpr 
auto b_thread_desc_ =
 
  643                                                          decltype(a_block_desc_m0_m1_m2_k),
 
  644                                                          decltype(a_thread_desc_),
 
  653                                                          decltype(b_block_desc_n0_n1_n2_k),
 
  654                                                          decltype(b_thread_desc_),
 
  661     AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
 
  662     BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
 
  663     using Base::c_thread_desc_;
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:299
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
 
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
 
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:58
 
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
 
ck::BlockwiseGemmXdlops_pipeline_v5< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::HotLoopScheduler static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops_v5.hpp:171
 
ck::BlockwiseGemmXdlops_pipeline_v5< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v5.hpp:336
 
ck::BlockwiseGemmXdlops_pipeline_v5< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum static constexpr __host__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v5.hpp:159
 
ck::BlockwiseGemmXdlops_pipeline_v5< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop static constexpr __host__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v5.hpp:154
 
Definition: blockwise_gemm_pipeline_xdlops_v5.hpp:37
 
Definition: sequence.hpp:43
 
ck::ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataTypeBuf, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 >  
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: dtype_vector.hpp:10