20           typename ComputeDataType,
 
   24           typename AMmaTileDesc,
 
   25           typename BMmaTileDesc,
 
   26           index_t ABlockTransferSrcScalarPerVector,
 
   27           index_t BBlockTransferSrcScalarPerVector,
 
   43           typename ComputeDataType,
 
   47           typename AMmaTileDesc,
 
   48           typename BMmaTileDesc,
 
   49           index_t ABlockTransferSrcScalarPerVector,
 
   50           index_t BBlockTransferSrcScalarPerVector,
 
   71                                        ABlockTransferSrcScalarPerVector,
 
   72                                        BBlockTransferSrcScalarPerVector,
 
   90                                         ABlockTransferSrcScalarPerVector,
 
   91                                         BBlockTransferSrcScalarPerVector,
 
  111                                                    ABlockTransferSrcScalarPerVector,
 
  112                                                    BBlockTransferSrcScalarPerVector,
 
  123     using Base::xdlops_gemm;
 
  125     using Base::CalculateCThreadOriginDataIndex;
 
  126     using Base::CalculateCThreadOriginDataIndex8D;
 
  127     using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  128     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  129     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  130     using Base::GetCThreadBuffer;
 
  131     using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  132     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  133     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  134     using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  135     using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  137     using Base::a_block_desc_m0_m1_m2_k;
 
  138     using Base::b_block_desc_n0_n1_n2_k;
 
  140     using Base::AMmaKStride;
 
  141     using Base::BMmaKStride;
 
  142     using Base::WaveSize;
 
  147         (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
 
  150         (MPerBlock * 
sizeof(ADataType) + NPerBlock * 
sizeof(BDataType)) * KPerBlock);
 
  152         FullMemBandPrefetchStages >= 2
 
  153             ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
 
  157     static constexpr 
index_t GlobalBufferNum = PrefetchStages;
 
  161         return num_loop > PrefetchStages;
 
  166         if(num_loop % PrefetchStages == 1)
 
  170         else if(num_loop % PrefetchStages == 2)
 
  174         else if(num_loop % PrefetchStages == 3)
 
  178         else if(num_loop % PrefetchStages == 4)
 
  182         else if(num_loop % PrefetchStages == 5)
 
  186         else if(num_loop % PrefetchStages == 6)
 
  190         else if(num_loop % PrefetchStages == 7)
 
  200     template <
bool HasMainLoop,
 
  204               typename ABlockTransfer,
 
  205               typename AGridBuffer,
 
  206               typename ABlockBuffer,
 
  207               typename ABlockTransferStep,
 
  210               typename BBlockTransfer,
 
  211               typename BGridBuffer,
 
  212               typename BBlockBuffer,
 
  213               typename BBlockTransferStep,
 
  214               typename CThreadBuffer>
 
  215     __device__ 
void Run(
const AGridDesc& a_grid_desc,
 
  216                         const ABlockDesc& a_block_desc,
 
  217                         ABlockTransfer& a_blockwise_copy,
 
  218                         const AGridBuffer& a_grid_buf,
 
  219                         ABlockBuffer& a_block_buf,
 
  220                         const ABlockTransferStep& a_block_copy_step,
 
  221                         const BGridDesc& b_grid_desc,
 
  222                         const BBlockDesc& b_block_desc,
 
  223                         BBlockTransfer& b_blockwise_copy,
 
  224                         const BGridBuffer& b_grid_buf,
 
  225                         BBlockBuffer& b_block_buf,
 
  226                         const BBlockTransferStep& b_block_copy_step,
 
  227                         CThreadBuffer& c_thread_buf,
 
  230         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
 
  231             a_thread_desc_.GetElementSpaceSize());
 
  232         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
 
  233             b_thread_desc_.GetElementSpaceSize());
 
  236         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
 
  237         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
 
  239         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  240         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  243         c_thread_buf.Clear();
 
  246         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
 
  247         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
 
  251             a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
 
  252             b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
 
  254             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  255             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  259         if constexpr(HasMainLoop)
 
  269                             a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  277                             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  293                                     a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  294                                         a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  296                                     b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  297                                         b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  301                                 using mfma_input_type =
 
  303                                                          xdlops_gemm.K1PerXdlops>::type;
 
  306                                     c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  309                                     a_thread_vec.template AsType<mfma_input_type>(),
 
  310                                     b_thread_vec.template AsType<mfma_input_type>(),
 
  317                     a_blockwise_copy.RunWrite(
 
  318                         a_block_desc, a_block_buf, 
Number<(iprefetch + 1) % PrefetchStages>{});
 
  319                     b_blockwise_copy.RunWrite(
 
  320                         b_block_desc, b_block_buf, 
Number<(iprefetch + 1) % PrefetchStages>{});
 
  322                     a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
 
  323                     b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
 
  325                     a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  326                     b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  330             } 
while(i < (num_loop - PrefetchStages));
 
  335         auto LoopTailFunc = [&](
auto tail_num) {
 
  340                         a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  348                         b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  364                                 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  365                                     a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  367                                 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  368                                     b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  372                             using mfma_input_type =
 
  374                                                      xdlops_gemm.K1PerXdlops>::type;
 
  377                                 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  380                                 a_thread_vec.template AsType<mfma_input_type>(),
 
  381                                 b_thread_vec.template AsType<mfma_input_type>(),
 
  388                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
 
  389                 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
 
  395                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  403                     b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  419                             a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  420                                 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  422                             b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  423                                 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  427                         using mfma_input_type =
 
  431                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  433                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  434                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  446                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  454                     b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  470                             a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  471                                 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  473                             b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  474                                 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  478                         using mfma_input_type =
 
  482                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  484                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  485                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  522     using Base::a_thread_copy_;
 
  523     using Base::a_thread_desc_;
 
  524     using Base::b_thread_copy_;
 
  525     using Base::b_thread_desc_;
 
  526     using Base::c_thread_desc_;
 
  532           typename ComputeDataType,
 
  533           typename AccDataType,
 
  536           typename AMmaTileDesc,
 
  537           typename BMmaTileDesc,
 
  538           index_t ABlockTransferSrcScalarPerVector,
 
  539           index_t BBlockTransferSrcScalarPerVector,
 
  560                                        ABlockTransferSrcScalarPerVector,
 
  561                                        BBlockTransferSrcScalarPerVector,
 
  579                                         ABlockTransferSrcScalarPerVector,
 
  580                                         BBlockTransferSrcScalarPerVector,
 
  600                                                    ABlockTransferSrcScalarPerVector,
 
  601                                                    BBlockTransferSrcScalarPerVector,
 
  614     using Base::KPerThread;
 
  615     using Base::xdlops_gemm;
 
  617     using Base::CalculateCThreadOriginDataIndex;
 
  618     using Base::CalculateCThreadOriginDataIndex8D;
 
  619     using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  620     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  621     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  622     using Base::GetCThreadBuffer;
 
  623     using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  624     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  625     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  626     using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  627     using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  629     using Base::a_block_desc_m0_m1_m2_k;
 
  630     using Base::b_block_desc_n0_n1_n2_k;
 
  631     using Base::WaveSize;
 
  637     static constexpr 
index_t KRepeat        = KPerThread / KPerInnerLoop;
 
  640         (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
 
  643         (MPerBlock * 
sizeof(ADataType) + NPerBlock * 
sizeof(BDataType)) * KPerBlock);
 
  645         FullMemBandPrefetchStages >= 2
 
  646             ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
 
  650     static constexpr 
index_t GlobalBufferNum = PrefetchStages;
 
  654         return num_loop > PrefetchStages;
 
  659         if(num_loop % PrefetchStages == 1)
 
  663         else if(num_loop % PrefetchStages == 2)
 
  667         else if(num_loop % PrefetchStages == 3)
 
  671         else if(num_loop % PrefetchStages == 4)
 
  675         else if(num_loop % PrefetchStages == 5)
 
  679         else if(num_loop % PrefetchStages == 6)
 
  683         else if(num_loop % PrefetchStages == 7)
 
  693     template <
bool HasMainLoop,
 
  697               typename ABlockTransfer,
 
  698               typename AGridBuffer,
 
  699               typename ABlockBuffer,
 
  700               typename ABlockTransferStep,
 
  703               typename BBlockTransfer,
 
  704               typename BGridBuffer,
 
  705               typename BBlockBuffer,
 
  706               typename BBlockTransferStep,
 
  707               typename CThreadBuffer>
 
  708     __device__ 
void Run(
const AGridDesc& a_grid_desc,
 
  709                         const ABlockDesc& a_block_desc,
 
  710                         ABlockTransfer& a_blockwise_copy,
 
  711                         const AGridBuffer& a_grid_buf,
 
  712                         ABlockBuffer& a_block_buf,
 
  713                         const ABlockTransferStep& a_block_copy_step,
 
  714                         const BGridDesc& b_grid_desc,
 
  715                         const BBlockDesc& b_block_desc,
 
  716                         BBlockTransfer& b_blockwise_copy,
 
  717                         const BGridBuffer& b_grid_buf,
 
  718                         BBlockBuffer& b_block_buf,
 
  719                         const BBlockTransferStep& b_block_copy_step,
 
  720                         CThreadBuffer& c_thread_buf,
 
  723         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
 
  724             a_thread_desc_.GetElementSpaceSize());
 
  725         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
 
  726             b_thread_desc_.GetElementSpaceSize());
 
  729         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
 
  730         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
 
  732         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  733         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  736         c_thread_buf.Clear();
 
  739         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
 
  740         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
 
  744             a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
 
  745             b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
 
  747             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  748             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  752         if constexpr(HasMainLoop)
 
  762                             a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  770                             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  777                         __builtin_amdgcn_sched_barrier(0);
 
  785                         if constexpr(k0.value != 0 || KRepeat == 1)
 
  787                             __builtin_amdgcn_s_barrier();
 
  788                             __builtin_amdgcn_sched_barrier(0);
 
  797                                         a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  798                                             a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  800                                         b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  801                                             b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  805                                     using mfma_input_type =
 
  807                                                              xdlops_gemm.K1PerXdlops>::type;
 
  810                                         c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  818                                     if constexpr(k0.value == KRepeat - 1 &&
 
  819                                                  k_.value == KPerInnerLoop - KPack &&
 
  820                                                  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
 
  822                                         __builtin_amdgcn_sched_barrier(0);
 
  824                                         __builtin_amdgcn_sched_barrier(0);
 
  827                                         a_thread_vec.template AsType<mfma_input_type>(),
 
  828                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  830                                     if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
 
  832                                         __builtin_amdgcn_sched_barrier(0);
 
  833                                         __builtin_amdgcn_s_setprio(1);
 
  834                                         __builtin_amdgcn_sched_barrier(0);
 
  839                         __builtin_amdgcn_sched_barrier(0);
 
  840                         __builtin_amdgcn_s_setprio(0);
 
  841                         __builtin_amdgcn_sched_barrier(0);
 
  845                     a_blockwise_copy.RunWrite(
 
  846                         a_block_desc, a_block_buf, 
Number<(iprefetch + 1) % PrefetchStages>{});
 
  847                     b_blockwise_copy.RunWrite(
 
  848                         b_block_desc, b_block_buf, 
Number<(iprefetch + 1) % PrefetchStages>{});
 
  850                     a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
 
  851                     b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
 
  853                     a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  854                     b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  857             } 
while(i < (num_loop - PrefetchStages));
 
  862         auto LoopTailFunc = [&](
auto tail_num) {
 
  867                         a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  875                         b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  883                     __builtin_amdgcn_sched_barrier(0);
 
  884                     if constexpr(k0.value != 0 || KRepeat == 1)
 
  886                         __builtin_amdgcn_s_barrier();
 
  887                         __builtin_amdgcn_sched_barrier(0);
 
  896                                     a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  897                                         a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  899                                     b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  900                                         b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  904                                 using mfma_input_type =
 
  906                                                          xdlops_gemm.K1PerXdlops>::type;
 
  909                                     c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  911                                 if constexpr(k0.value == KRepeat - 1 &&
 
  912                                              k_.value == KPerInnerLoop - KPack &&
 
  913                                              m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
 
  915                                     __builtin_amdgcn_sched_barrier(0);
 
  917                                     __builtin_amdgcn_sched_barrier(0);
 
  920                                     a_thread_vec.template AsType<mfma_input_type>(),
 
  921                                     b_thread_vec.template AsType<mfma_input_type>(),
 
  923                                 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
 
  925                                     __builtin_amdgcn_sched_barrier(0);
 
  926                                     __builtin_amdgcn_s_setprio(1);
 
  927                                     __builtin_amdgcn_sched_barrier(0);
 
  932                     __builtin_amdgcn_sched_barrier(0);
 
  933                     __builtin_amdgcn_s_setprio(0);
 
  934                     __builtin_amdgcn_sched_barrier(0);
 
  937                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
 
  938                 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
 
  943                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  951                     b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  959                 __builtin_amdgcn_sched_barrier(0);
 
  960                 if constexpr(k0.value != 0 || KRepeat == 1)
 
  962                     __builtin_amdgcn_s_barrier();
 
  963                     __builtin_amdgcn_sched_barrier(0);
 
  972                                 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  973                                     a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  975                                 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
  976                                     b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  980                             using mfma_input_type =
 
  982                                                      xdlops_gemm.K1PerXdlops>::type;
 
  985                                 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  987                             if constexpr(k0.value == KRepeat - 1 &&
 
  988                                          k_.value == KPerInnerLoop - KPack &&
 
  989                                          m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
 
  991                                 __builtin_amdgcn_sched_barrier(0);
 
  993                                 __builtin_amdgcn_sched_barrier(0);
 
  996                                 a_thread_vec.template AsType<mfma_input_type>(),
 
  997                                 b_thread_vec.template AsType<mfma_input_type>(),
 
  999                             if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
 
 1001                                 __builtin_amdgcn_sched_barrier(0);
 
 1002                                 __builtin_amdgcn_s_setprio(1);
 
 1003                                 __builtin_amdgcn_sched_barrier(0);
 
 1008                 __builtin_amdgcn_sched_barrier(0);
 
 1009                 __builtin_amdgcn_s_setprio(0);
 
 1010                 __builtin_amdgcn_sched_barrier(0);
 
 1019                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
 1027                     b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
 1035                 __builtin_amdgcn_sched_barrier(0);
 
 1036                 if constexpr(k0.value != 0 || KRepeat == 1)
 
 1038                     __builtin_amdgcn_s_barrier();
 
 1039                     __builtin_amdgcn_sched_barrier(0);
 
 1048                                 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
 1049                                     a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
 1051                                 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
 
 1052                                     b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
 1056                             using mfma_input_type =
 
 1058                                                      xdlops_gemm.K1PerXdlops>::type;
 
 1061                                 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
 1063                             if constexpr(k0.value == KRepeat - 1 &&
 
 1064                                          k_.value == KPerInnerLoop - KPack &&
 
 1065                                          m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
 
 1067                                 __builtin_amdgcn_sched_barrier(0);
 
 1069                                 __builtin_amdgcn_sched_barrier(0);
 
 1072                                 a_thread_vec.template AsType<mfma_input_type>(),
 
 1073                                 b_thread_vec.template AsType<mfma_input_type>(),
 
 1075                             if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
 
 1077                                 __builtin_amdgcn_sched_barrier(0);
 
 1078                                 __builtin_amdgcn_s_setprio(1);
 
 1079                                 __builtin_amdgcn_sched_barrier(0);
 
 1084                 __builtin_amdgcn_sched_barrier(0);
 
 1085                 __builtin_amdgcn_s_setprio(0);
 
 1086                 __builtin_amdgcn_sched_barrier(0);
 
 1124                    Number<KRepeat * MRepeat * KPerInnerLoop>{},
 
 1125                    Number<MRepeat * KPerInnerLoop>{},
 
 1131                    Number<KRepeat * NRepeat * KPerInnerLoop>{},
 
 1132                    Number<NRepeat * KPerInnerLoop>{},
 
 1137                                                          decltype(a_block_desc_m0_m1_m2_k),
 
 1138                                                          decltype(a_thread_desc_),
 
 1147                                                          decltype(b_block_desc_n0_n1_n2_k),
 
 1148                                                          decltype(b_thread_desc_),
 
 1157     using Base::c_thread_desc_;
 
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition: ck.hpp:209
 
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:299
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
 
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
 
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:58
 
ck::BlockwiseGemmXdlops_pipeline_v2< BlockGemmPipelineScheduler::Interwave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v2.hpp:708
 
ck::BlockwiseGemmXdlops_pipeline_v2< BlockGemmPipelineScheduler::Interwave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v2.hpp:657
 
ck::BlockwiseGemmXdlops_pipeline_v2< BlockGemmPipelineScheduler::Interwave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v2.hpp:652
 
ck::BlockwiseGemmXdlops_pipeline_v2< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v2.hpp:215
 
ck::BlockwiseGemmXdlops_pipeline_v2< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v2.hpp:164
 
ck::BlockwiseGemmXdlops_pipeline_v2< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v2.hpp:159
 
Definition: blockwise_gemm_pipeline_xdlops_v2.hpp:37
 
Definition: sequence.hpp:43
 
ck::ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataTypeBuf, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 >  
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: dtype_vector.hpp:10