// Fragments of the BlockwiseGemmXdlops_pipeline_v2 template-parameter list
// (primary template; unrelated parameters elided in this extract):
          typename ComputeDataType,
          // ...
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,

// The Intrawave specialization (BlockGemmPipelineScheduler::Intrawave)
// repeats the same parameter list:
          typename ComputeDataType,
          // ...
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,

// ...and forwards the vector widths to its pipeline base class (the pair
// recurs in each base-class argument list):
          ABlockTransferSrcScalarPerVector,
          BBlockTransferSrcScalarPerVector,
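// Hedged instantiation sketch (argument values are illustrative, not from
// this header; descriptor types abbreviated). The parameter order is:
//   BlockwiseGemmXdlops_pipeline_v2<
//       BlockGemmPipelineScheduler::Intrawave,
//       /*BlockSize*/ 256,
//       /*ADataType, BDataType*/ half_t, half_t,
//       /*ComputeDataType, AccDataType*/ half_t, float,
//       /*ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc*/ ...,
//       /*ABlockTransferSrcScalarPerVector*/ 8,
//       /*BBlockTransferSrcScalarPerVector*/ 8,
//       /*MPerBlock, NPerBlock, KPerBlock*/ 128, 128, 32,
//       /*MPerXDL, NPerXDL*/ 32, 32,
//       /*MRepeat, NRepeat*/ 2, 2,
//       /*KPack*/ 8>;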
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;
    using Base::WaveSize;
    static constexpr index_t WavesPerCuPerBlock = // hypothetical name; LHS elided in extract
        (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;

    // Number of prefetch stages needed to cover full memory bandwidth,
    // derived from the per-stage tile footprint (first argument of the
    // divide elided in this extract)
    static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
        /* bandwidth-derived byte budget, elided */,
        (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

    static constexpr index_t PrefetchStages =
        FullMemBandPrefetchStages >= 2
            ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
            : 2; // lower bound elided in extract; 2 completes the clamp into [2, 8]

    static constexpr index_t GlobalBufferNum = PrefetchStages;
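    // Worked example (illustrative values, not from this header): with
    // WaveSize = 64 and BlockSize = 256, the ratio is 4*64/256 = 1; with
    // MPerBlock = NPerBlock = 128, KPerBlock = 32, and 2-byte A/B elements,
    // one stage moves (128*2 + 128*2)*32 = 16 KiB, so the byte budget
    // divided by 16 KiB gives the raw stage count, which the ternary above
    // then clamps into [2, 8].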
    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        // A steady-state main loop exists only when there are more K loops
        // than in-flight prefetch stages
        return num_loop > PrefetchStages;
    }
    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        // Branch bodies are elided in this extract; each branch returns the
        // TailNumber enumerator matching the remainder, with a default for
        // a full tail
        if(num_loop % PrefetchStages == 1) { /* ... */ }
        else if(num_loop % PrefetchStages == 2) { /* ... */ }
        else if(num_loop % PrefetchStages == 3) { /* ... */ }
        else if(num_loop % PrefetchStages == 4) { /* ... */ }
        else if(num_loop % PrefetchStages == 5) { /* ... */ }
        else if(num_loop % PrefetchStages == 6) { /* ... */ }
        else if(num_loop % PrefetchStages == 7) { /* ... */ }
        // ...
    }
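    // With PrefetchStages tiles in flight, the epilogue must drain
    // num_loop % PrefetchStages leftover stages, so callers pick a tail
    // variant up front. Hedged host-side sketch (the dispatch scaffolding
    // is illustrative, not from this header):
    //
    //   const bool has_hot_loop = BlockHasHotloop(num_k_loop);
    //   const TailNumber tail   = BlockLoopTailNum(num_k_loop);
    //   // select the matching Run<HasMainLoop, ...> instantiation here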
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
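        // Pipeline shape: the prologue below prefetches PrefetchStages tiles
        // from global memory, the main do/while overlaps MFMA compute with
        // LDS writes and the next global reads, and the tail drains the
        // stages still in flight.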
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());
        // Prologue: global read of stage 0, then advance the source windows
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize the C accumulators
        c_thread_buf.Clear();

        // Write stage 0 to LDS
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);

        // Global read of the remaining stages (enclosing static_for over
        // iprefetch elided in this extract)
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
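        // At this point stage 0 sits in LDS and stages 1..PrefetchStages-1
        // are in flight in separate global-side buffers (GlobalBufferNum
        // reserves one buffer per stage above).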
        // Main body (steady state)
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // (preceding LDS synchronization elided in this extract)

                // LDS -> VGPR: read the current A/B tiles (trailing copy
                // arguments elided)
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   /* ... */);
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   /* ... */);

                // Pack A/B into MFMA input vectors (enclosing loops over
                // m0/n0/ik elided)
                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                        /* ... */)>{}];
                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                        /* ... */)>{}];

                using mfma_input_type =
                    typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                constexpr index_t c_offset = // hypothetical name; LHS elided in extract
                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                /* C accumulator reference, elided */);

                // Write the next prefetched stage to LDS, round-robin over
                // the PrefetchStages buffers
                a_blockwise_copy.RunWrite(
                    a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
                b_blockwise_copy.RunWrite(
                    b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});

                // Refill the consumed global buffer and advance the windows
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                i += PrefetchStages; // loop-counter advance (elided in this extract)
            } while(i < (num_loop - PrefetchStages));
        }
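        // Each do/while trip walks through all PrefetchStages buffers:
        // compute runs on the stage just written to LDS while the copies
        // refill the buffer that stage vacated, keeping global reads
        // PrefetchStages tiles ahead of the MFMAs.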
        // Tail
        auto LoopTailFunc = [&](auto tail_num) {
            // Drain loop: compute the current stage, then write the next
            // prefetched stage to LDS (enclosing static_for over iprefetch
            // and the LDS synchronizations elided in this extract)
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               /* ... */);
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               /* ... */);

            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    /* ... */)>{}];
            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                    /* ... */)>{}];

            using mfma_input_type =
                typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

            constexpr index_t c_offset = // hypothetical name; LHS elided in extract
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

            xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                            b_thread_vec.template AsType<mfma_input_type>(),
                            /* C accumulator reference, elided */);

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
            // Final stage: compute only, no further copies
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               /* ... */);
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               /* ... */);

            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    /* ... */)>{}];
            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                    /* ... */)>{}];

            using mfma_input_type =
                typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

            constexpr index_t c_offset = // hypothetical name; LHS elided in extract
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

            xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                            b_thread_vec.template AsType<mfma_input_type>(),
                            /* C accumulator reference, elided */);
        };
        // Tail dispatch: the compute-only pass below handles the last stage;
        // the enclosing if constexpr over the TailNumber variants is elided
        // in this extract
        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                           /* ... */);
        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                           /* ... */);

        a_thread_vec.template AsType<ComputeDataType>()(ik) =
            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                /* ... */)>{}];
        b_thread_vec.template AsType<ComputeDataType>()(ik) =
            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                /* ... */)>{}];

        using mfma_input_type =
            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

        constexpr index_t c_offset = // hypothetical name; LHS elided in extract
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                        b_thread_vec.template AsType<mfma_input_type>(),
                        /* C accumulator reference, elided */);
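        // The remaining TailNumber branches and the end of Run() are elided
        // in this extract; longer tails drain by invoking LoopTailFunc with
        // the matching stage count.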
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
// (end of the Intrawave specialization; closing braces elided in this extract)

// The Interwave specialization (BlockGemmPipelineScheduler::Interwave)
// repeats the same template-parameter shape; unrelated parameters elided:
          typename ComputeDataType,
          typename AccDataType,
          // ...
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,

// ...again forwarding the vector widths to the pipeline base class (the
// pair recurs in each base-class argument list):
          ABlockTransferSrcScalarPerVector,
          BBlockTransferSrcScalarPerVector,
    using Base::KPerThread;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;
    using Base::WaveSize;
    // KPerInnerLoop's definition is elided in this extract; it divides
    // KPerThread into KRepeat inner slices
    static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;

    static constexpr index_t WavesPerCuPerBlock = // hypothetical name; LHS elided in extract
        (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;

    static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
        /* bandwidth-derived byte budget, elided */,
        (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

    static constexpr index_t PrefetchStages =
        FullMemBandPrefetchStages >= 2
            ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
            : 2; // lower bound elided in extract; 2 completes the clamp into [2, 8]

    static constexpr index_t GlobalBufferNum = PrefetchStages;
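    // Each of the KRepeat slices gets its own barrier/priority window in
    // the interwave Run body below; the prefetch-stage sizing itself is
    // identical to the Intrawave variant above.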
    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }
    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        // Same remainder dispatch as the Intrawave variant above; branch
        // bodies elided in this extract
        if(num_loop % PrefetchStages == 1) { /* ... */ }
        else if(num_loop % PrefetchStages == 2) { /* ... */ }
        else if(num_loop % PrefetchStages == 3) { /* ... */ }
        else if(num_loop % PrefetchStages == 4) { /* ... */ }
        else if(num_loop % PrefetchStages == 5) { /* ... */ }
        else if(num_loop % PrefetchStages == 6) { /* ... */ }
        else if(num_loop % PrefetchStages == 7) { /* ... */ }
        // ...
    }
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());
        // Prologue: identical staging to the Intrawave variant above
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        c_thread_buf.Clear();

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);

        // Global read of the remaining stages (enclosing static_for over
        // iprefetch elided in this extract)
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        // Main body (steady state, interwave-scheduled)
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // LDS -> VGPR reads for one k0 slice (enclosing static_for
                // over k0 and trailing copy arguments elided in this extract)
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   /* ... */);
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   /* ... */);

                __builtin_amdgcn_sched_barrier(0);
                // All waves rendezvous before the MFMA burst, except on the
                // first slice, which may still overlap the LDS writes
                if constexpr(k0.value != 0 || KRepeat == 1)
                {
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                }

                // Pack A/B into MFMA input vectors (enclosing loops over
                // k_/m0/n0/ik elided)
                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                        /* ... */)>{}];
                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                        /* ... */)>{}];

                using mfma_input_type =
                    typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                constexpr index_t c_offset = // hypothetical name; LHS elided in extract
                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                // Around the very last MFMA of the slice, fence scheduling
                // on both sides of the LDS synchronization (body elided)
                if constexpr(k0.value == KRepeat - 1 &&
                             k_.value == KPerInnerLoop - KPack &&
                             m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
                {
                    __builtin_amdgcn_sched_barrier(0);
                    /* ... */
                    __builtin_amdgcn_sched_barrier(0);
                }

                xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                /* C accumulator reference, elided */);

                // Raise wave priority once the slice's first MFMA has issued
                if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
                {
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_setprio(1);
                    __builtin_amdgcn_sched_barrier(0);
                }

                // ...and drop it back after the slice's MFMAs are all issued
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_sched_barrier(0);
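                // Net effect of the intrinsics above (hedged summary): a
                // wave raises its priority right after its first MFMA issues
                // and drops it after the slice's last MFMA, so MFMA bursts
                // from different waves interleave with the other waves' LDS
                // and global traffic instead of contending for the same
                // pipe; sched_barrier(0) keeps the compiler from reordering
                // work across those switch points.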
                // Round-robin LDS write of the next prefetched stage, then
                // refill its global buffer and advance the source windows
                a_blockwise_copy.RunWrite(
                    a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
                b_blockwise_copy.RunWrite(
                    b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                i += PrefetchStages; // loop-counter advance (elided in this extract)
            } while(i < (num_loop - PrefetchStages));
        }
        // Tail
        auto LoopTailFunc = [&](auto tail_num) {
            // Drain loop: same slice sequence as the hot loop (enclosing
            // static_for over iprefetch/k0 elided in this extract)
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               /* ... */);
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               /* ... */);

            __builtin_amdgcn_sched_barrier(0);
            if constexpr(k0.value != 0 || KRepeat == 1)
            {
                __builtin_amdgcn_s_barrier();
                __builtin_amdgcn_sched_barrier(0);
            }

            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    /* ... */)>{}];
            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                    /* ... */)>{}];

            using mfma_input_type =
                typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

            constexpr index_t c_offset = // hypothetical name; LHS elided in extract
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

            if constexpr(k0.value == KRepeat - 1 &&
                         k_.value == KPerInnerLoop - KPack &&
                         m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
            {
                __builtin_amdgcn_sched_barrier(0);
                /* ... */
                __builtin_amdgcn_sched_barrier(0);
            }

            xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                            b_thread_vec.template AsType<mfma_input_type>(),
                            /* C accumulator reference, elided */);

            if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
            {
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(1);
                __builtin_amdgcn_sched_barrier(0);
            }

            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_setprio(0);
            __builtin_amdgcn_sched_barrier(0);

            // Write the next prefetched stage to LDS
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
            // Final stage: compute only, no further copies
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               /* ... */);
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               /* ... */);

            __builtin_amdgcn_sched_barrier(0);
            if constexpr(k0.value != 0 || KRepeat == 1)
            {
                __builtin_amdgcn_s_barrier();
                __builtin_amdgcn_sched_barrier(0);
            }

            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    /* ... */)>{}];
            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                    /* ... */)>{}];

            using mfma_input_type =
                typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

            constexpr index_t c_offset = // hypothetical name; LHS elided in extract
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

            if constexpr(k0.value == KRepeat - 1 &&
                         k_.value == KPerInnerLoop - KPack &&
                         m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
            {
                __builtin_amdgcn_sched_barrier(0);
                /* ... */
                __builtin_amdgcn_sched_barrier(0);
            }

            xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                            b_thread_vec.template AsType<mfma_input_type>(),
                            /* C accumulator reference, elided */);

            if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
            {
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(1);
                __builtin_amdgcn_sched_barrier(0);
            }

            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_setprio(0);
            __builtin_amdgcn_sched_barrier(0);
        };
        // Tail dispatch: compute-only pass over the last stage (the
        // enclosing if constexpr over the TailNumber variants is elided in
        // this extract)
        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                           /* ... */);
        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                           /* ... */);

        __builtin_amdgcn_sched_barrier(0);
        if constexpr(k0.value != 0 || KRepeat == 1)
        {
            __builtin_amdgcn_s_barrier();
            __builtin_amdgcn_sched_barrier(0);
        }

        a_thread_vec.template AsType<ComputeDataType>()(ik) =
            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                /* ... */)>{}];
        b_thread_vec.template AsType<ComputeDataType>()(ik) =
            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                /* ... */)>{}];

        using mfma_input_type =
            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

        constexpr index_t c_offset = // hypothetical name; LHS elided in extract
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

        if constexpr(k0.value == KRepeat - 1 &&
                     k_.value == KPerInnerLoop - KPack &&
                     m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
        {
            __builtin_amdgcn_sched_barrier(0);
            /* ... */
            __builtin_amdgcn_sched_barrier(0);
        }

        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                        b_thread_vec.template AsType<mfma_input_type>(),
                        /* C accumulator reference, elided */);

        if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
        {
            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_setprio(1);
            __builtin_amdgcn_sched_barrier(0);
        }

        __builtin_amdgcn_sched_barrier(0);
        __builtin_amdgcn_s_setprio(0);
        __builtin_amdgcn_sched_barrier(0);

        // ... (remaining tail branches and function close elided in this extract)
    // Per-thread A/B VGPR descriptors (make_naive_tensor_descriptor calls
    // elided down to their KRepeat-major stride expressions):
    //   A strides:
        Number<KRepeat * MRepeat * KPerInnerLoop>{},
        Number<MRepeat * KPerInnerLoop>{},
        /* ... */
    //   B strides:
        Number<KRepeat * NRepeat * KPerInnerLoop>{},
        Number<NRepeat * KPerInnerLoop>{},
        /* ... */

    // LDS -> VGPR copier types (ThreadwiseTensorSliceTransfer_v4
    // instantiations; surrounding template arguments elided):
        decltype(a_block_desc_m0_m1_m2_k),
        decltype(a_thread_desc_),
        /* ... */
        decltype(b_block_desc_n0_n1_n2_k),
        decltype(b_thread_desc_),
        /* ... */
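    // The KRepeat-major strides above keep each KPerInnerLoop slice
    // contiguous in VGPRs, so the per-slice LDS reads in the Run body land
    // as dense vector accesses.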
    using Base::c_thread_desc_;