20           typename ComputeTypeA,
 
   21           typename ComputeTypeB,
 
   23           typename AWmmaTileDesc,
 
   24           typename BWmmaTileDesc,
 
   25           index_t ABlockTransferSrcScalarPerVector,
 
   26           index_t BBlockTransferSrcScalarPerVector,
 
   35           bool TransposeC = 
false>
 
   43           typename ComputeTypeA,
 
   44           typename ComputeTypeB,
 
   46           typename AWmmaTileDesc,
 
   47           typename BWmmaTileDesc,
 
   48           index_t ABlockTransferSrcScalarPerVector,
 
   49           index_t BBlockTransferSrcScalarPerVector,
 
   68                                         ABlockTransferSrcScalarPerVector,
 
   69                                         BBlockTransferSrcScalarPerVector,
 
   87                                          ABlockTransferSrcScalarPerVector,
 
   88                                          BBlockTransferSrcScalarPerVector,
 
  107                                                     ABlockTransferSrcScalarPerVector,
 
  108                                                     BBlockTransferSrcScalarPerVector,
 
  127     using Base::wmma_gemm;
 
  130     using Base::CalculateCThreadOriginDataIndex;
 
  132         GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
 
  133     using Base::GetCThreadBuffer;
 
  135         GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
 
  137         GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;
 
  139     using Base::a_block_desc_k0_m0_m1_m2_k1;
 
  140     using Base::b_block_desc_k0_n0_n1_n2_k1;
 
  142     using typename Base::Empty;
 
  150         return num_loop > PrefetchStages;
 
  155         if(BlockHasHotloop(num_loop))
 
  280     template <
typename ABlockBuffer,
 
  281               typename AThreadBuffer,
 
  282               typename BBlockBuffer,
 
  283               typename BThreadBuffer,
 
  284               typename BScaleStruct>
 
  285     __device__ 
inline void LocalLoad(ABlockBuffer& a_block_buf,
 
  286                                      AThreadBuffer& a_thread_buf,
 
  287                                      BBlockBuffer& b_block_buf,
 
  288                                      BThreadBuffer& b_thread_buf,
 
  289                                      BScaleStruct& b_scale_struct)
 const 
  294                     a_block_desc_k0_m0_m1_m2_k1,
 
  302             if constexpr(ck::is_same_v<BScaleStruct, Empty>)
 
  306                         b_block_desc_k0_n0_n1_n2_k1,
 
  318                         b_block_desc_k0_n0_n1_n2_k1,
 
  321                         b_scale_struct.b_scale_thread_bufs(
 
  322                             I0)[
Number<n0 * BScaleStruct::num_scale_k_block +
 
  323                                        k0 / BScaleStruct::num_scale_krepeat>{}],
 
  332     template <
bool HasMainLoop,
 
  336               typename ABlockTransfer,
 
  337               typename AGridBuffer,
 
  338               typename ABlockBuffer,
 
  339               typename ABlockTransferStep,
 
  342               typename BBlockTransfer,
 
  343               typename BGridBuffer,
 
  344               typename BBlockBuffer,
 
  345               typename BBlockTransferStep,
 
  346               typename CThreadBuffer,
 
  347               typename BScaleStruct>
 
  348     __device__ 
void Run(
const AGridDesc& a_grid_desc,
 
  349                         const ABlockDesc& a_block_desc,
 
  350                         ABlockTransfer& a_blockwise_copy,
 
  351                         const AGridBuffer& a_grid_buf,
 
  352                         ABlockBuffer& a_block_buf,
 
  353                         const ABlockTransferStep& a_block_copy_step,
 
  354                         const BGridDesc& b_grid_desc,
 
  355                         const BBlockDesc& b_block_desc,
 
  356                         BBlockTransfer& b_blockwise_copy,
 
  357                         const BGridBuffer& b_grid_buf,
 
  358                         BBlockBuffer& b_block_buf,
 
  359                         const BBlockTransferStep& b_block_copy_step,
 
  360                         CThreadBuffer& c_thread_buf,
 
  362                         BScaleStruct& b_scale_struct,
 
  364                         index_t num_loop_per_scale)
 const 
  366         __builtin_amdgcn_sched_barrier(0);
 
  367         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
 
  368             a_thread_desc_.GetElementSpaceSize());
 
  369         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
 
  370             b_thread_desc_.GetElementSpaceSize());
 
  373         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  374         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  376         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  377         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  379         b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
 
  382         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 
  383         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
 
  388             a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  389             b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  391             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  392             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  396         c_thread_buf.Clear();
 
  401         LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
 
  403         __builtin_amdgcn_sched_barrier(0);
 
  406         if constexpr(HasMainLoop)
 
  413                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 
  414                 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
 
  416                 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  417                 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  419                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  420                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  422                 b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
 
  427                             vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
 
  428                             vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
 
  430                             static_for<0, KPack / A_KRow, 1>{}([&](
auto ik) {
 
  431                                 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
 
  432                                     a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  440                             static_for<0, KPack / B_KRow, 1>{}([&](
auto ik) {
 
  441                                 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
 
  442                                     b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  451                             using wmma_input_type_a =
 
  452                                 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
 
  453                             using wmma_input_type_b =
 
  454                                 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
 
  457                                 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
 
  459                             wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
 
  460                                           b_thread_vec.template AsType<wmma_input_type_b>(),
 
  468                 LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
 
  471                 __builtin_amdgcn_sched_barrier(0);
 
  474             } 
while(i < (num_loop - 2));
 
  482             a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 
  483             b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
 
  487             b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
 
  492                         vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
 
  493                         vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
 
  495                         static_for<0, KPack / A_KRow, 1>{}([&](
auto ik) {
 
  496                             a_thread_vec.template AsType<ComputeTypeA>()(ik) =
 
  500                         static_for<0, KPack / B_KRow, 1>{}([&](
auto ik) {
 
  501                             b_thread_vec.template AsType<ComputeTypeB>()(ik) =
 
  506                         using wmma_input_type_a =
 
  507                             typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
 
  508                         using wmma_input_type_b =
 
  509                             typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
 
  512                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
 
  514                         wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
 
  515                                       b_thread_vec.template AsType<wmma_input_type_b>(),
 
  523             LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
 
  526             __builtin_amdgcn_sched_barrier(0);
 
  534                         vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
 
  535                         vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
 
  537                         static_for<0, KPack / A_KRow, 1>{}([&](
auto ik) {
 
  538                             a_thread_vec.template AsType<ComputeTypeA>()(ik) =
 
  542                         static_for<0, KPack / B_KRow, 1>{}([&](
auto ik) {
 
  543                             b_thread_vec.template AsType<ComputeTypeB>()(ik) =
 
  548                         using wmma_input_type_a =
 
  549                             typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
 
  550                         using wmma_input_type_b =
 
  551                             typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
 
  554                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
 
  556                         wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
 
  557                                       b_thread_vec.template AsType<wmma_input_type_b>(),
 
  569     using Base::a_thread_copy_;
 
  570     using Base::a_thread_desc_;
 
  571     using Base::b_thread_copy_;
 
  572     using Base::b_thread_desc_;
 
  573     using Base::c_thread_desc_;
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:299
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
 
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:35
 
Definition: blockwise_gemm_pipeline_wmmaops.hpp:26
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, TransposeC >::LocalLoad __device__ void LocalLoad(ABlockBuffer &a_block_buf, AThreadBuffer &a_thread_buf, BBlockBuffer &b_block_buf, BThreadBuffer &b_thread_buf, BScaleStruct &b_scale_struct) const
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:285
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, TransposeC >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:153
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, TransposeC >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:148
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, TransposeC >::HotLoopScheduler static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:172
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, TransposeC >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:348
 
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:37
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: dtype_vector.hpp:10