20 typename ComputeDataType,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
43 typename ComputeDataType,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
123 using Base::xdlops_gemm;
125 using Base::CalculateCThreadOriginDataIndex;
126 using Base::CalculateCThreadOriginDataIndex8D;
127 using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
128 using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
129 using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
130 using Base::GetCThreadBuffer;
131 using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
132 using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
133 using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
134 using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
135 using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
137 using Base::a_block_desc_m0_m1_m2_k;
138 using Base::b_block_desc_n0_n1_n2_k;
140 using Base::AMmaKStride;
141 using Base::BMmaKStride;
151 return num_loop > PrefetchStages;
160 template <
bool HasMainLoop,
164 typename ABlockTransfer,
165 typename AGridBuffer,
166 typename ABlockBuffer,
167 typename ABlockTransferStep,
170 typename BBlockTransfer,
171 typename BGridBuffer,
172 typename BBlockBuffer,
173 typename BBlockTransferStep,
174 typename CThreadBuffer>
175 __device__
void Run(
const AGridDesc& a_grid_desc,
176 const ABlockDesc& a_block_desc,
177 ABlockTransfer& a_blockwise_copy,
178 const AGridBuffer& a_grid_buf,
179 ABlockBuffer& a_block_buf,
180 const ABlockTransferStep& a_block_copy_step,
181 const BGridDesc& b_grid_desc,
182 const BBlockDesc& b_block_desc,
183 BBlockTransfer& b_blockwise_copy,
184 const BGridBuffer& b_grid_buf,
185 BBlockBuffer& b_block_buf,
186 const BBlockTransferStep& b_block_copy_step,
187 CThreadBuffer& c_thread_buf,
190 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
191 a_thread_desc_.GetElementSpaceSize());
192 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
193 b_thread_desc_.GetElementSpaceSize());
196 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
197 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
199 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
200 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
203 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
204 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
207 c_thread_buf.Clear();
210 if constexpr(HasMainLoop)
216 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
217 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
219 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
220 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
225 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
232 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
249 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
250 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
252 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
253 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
257 using mfma_input_type =
259 xdlops_gemm.K1PerXdlops>::type;
262 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
265 a_thread_vec.template AsType<mfma_input_type>(),
266 b_thread_vec.template AsType<mfma_input_type>(),
273 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
274 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
277 }
while(i < (num_loop - 1));
286 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
293 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
310 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
311 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
313 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
314 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
318 using mfma_input_type =
322 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
324 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
325 b_thread_vec.template AsType<mfma_input_type>(),
334 using Base::a_thread_copy_;
335 using Base::a_thread_desc_;
336 using Base::b_thread_copy_;
337 using Base::b_thread_desc_;
338 using Base::c_thread_desc_;
344 typename ComputeDataType,
345 typename AccDataType,
348 typename AMmaTileDesc,
349 typename BMmaTileDesc,
350 index_t ABlockTransferSrcScalarPerVector,
351 index_t BBlockTransferSrcScalarPerVector,
372 ABlockTransferSrcScalarPerVector,
373 BBlockTransferSrcScalarPerVector,
391 ABlockTransferSrcScalarPerVector,
392 BBlockTransferSrcScalarPerVector,
412 ABlockTransferSrcScalarPerVector,
413 BBlockTransferSrcScalarPerVector,
426 using Base::KPerThread;
427 using Base::xdlops_gemm;
429 using Base::CalculateCThreadOriginDataIndex;
430 using Base::CalculateCThreadOriginDataIndex8D;
431 using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
432 using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
433 using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
434 using Base::GetCThreadBuffer;
435 using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
436 using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
437 using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
438 using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
439 using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
441 using Base::a_block_desc_m0_m1_m2_k;
442 using Base::b_block_desc_n0_n1_n2_k;
448 static constexpr
index_t KRepeat = KPerThread / KPerInnerLoop;
454 return num_loop > PrefetchStages;
463 template <
bool HasMainLoop,
467 typename ABlockTransfer,
468 typename AGridBuffer,
469 typename ABlockBuffer,
470 typename ABlockTransferStep,
473 typename BBlockTransfer,
474 typename BGridBuffer,
475 typename BBlockBuffer,
476 typename BBlockTransferStep,
477 typename CThreadBuffer>
478 __device__
void Run(
const AGridDesc& a_grid_desc,
479 const ABlockDesc& a_block_desc,
480 ABlockTransfer& a_blockwise_copy,
481 const AGridBuffer& a_grid_buf,
482 ABlockBuffer& a_block_buf,
483 const ABlockTransferStep& a_block_copy_step,
484 const BGridDesc& b_grid_desc,
485 const BBlockDesc& b_block_desc,
486 BBlockTransfer& b_blockwise_copy,
487 const BGridBuffer& b_grid_buf,
488 BBlockBuffer& b_block_buf,
489 const BBlockTransferStep& b_block_copy_step,
490 CThreadBuffer& c_thread_buf,
493 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
494 a_thread_desc_.GetElementSpaceSize());
495 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
496 b_thread_desc_.GetElementSpaceSize());
499 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
500 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
502 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
503 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
506 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
507 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
510 c_thread_buf.Clear();
513 if constexpr(HasMainLoop)
519 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
520 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
522 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
523 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
528 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
535 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
543 __builtin_amdgcn_sched_barrier(0);
550 if constexpr(k0.value != 0 || KRepeat == 1)
552 __builtin_amdgcn_s_barrier();
553 __builtin_amdgcn_sched_barrier(0);
562 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
563 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
565 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
566 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
570 using mfma_input_type =
572 xdlops_gemm.K1PerXdlops>::type;
575 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
582 if constexpr(k0.value == KRepeat - 1 &&
583 k_.value == KPerInnerLoop - KPack &&
584 m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
586 __builtin_amdgcn_sched_barrier(0);
588 __builtin_amdgcn_sched_barrier(0);
591 a_thread_vec.template AsType<mfma_input_type>(),
592 b_thread_vec.template AsType<mfma_input_type>(),
594 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
596 __builtin_amdgcn_sched_barrier(0);
597 __builtin_amdgcn_s_setprio(1);
598 __builtin_amdgcn_sched_barrier(0);
603 __builtin_amdgcn_sched_barrier(0);
604 __builtin_amdgcn_s_setprio(0);
605 __builtin_amdgcn_sched_barrier(0);
609 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
610 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
613 }
while(i < (num_loop - 1));
622 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
629 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
638 __builtin_amdgcn_sched_barrier(0);
639 if constexpr(k0.value != 0 || KRepeat == 1)
641 __builtin_amdgcn_s_barrier();
642 __builtin_amdgcn_sched_barrier(0);
651 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
652 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
654 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
655 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
659 using mfma_input_type =
661 xdlops_gemm.K1PerXdlops>::type;
664 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
666 if constexpr(k0.value == KRepeat - 1 &&
667 k_.value == KPerInnerLoop - KPack &&
668 m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
670 __builtin_amdgcn_sched_barrier(0);
672 __builtin_amdgcn_sched_barrier(0);
675 a_thread_vec.template AsType<mfma_input_type>(),
676 b_thread_vec.template AsType<mfma_input_type>(),
678 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
680 __builtin_amdgcn_sched_barrier(0);
681 __builtin_amdgcn_s_setprio(1);
682 __builtin_amdgcn_sched_barrier(0);
687 __builtin_amdgcn_sched_barrier(0);
688 __builtin_amdgcn_s_setprio(0);
689 __builtin_amdgcn_sched_barrier(0);
699 Number<KRepeat * MRepeat * KPerInnerLoop>{},
700 Number<MRepeat * KPerInnerLoop>{},
706 Number<KRepeat * NRepeat * KPerInnerLoop>{},
707 Number<NRepeat * KPerInnerLoop>{},
712 decltype(a_block_desc_m0_m1_m2_k),
713 decltype(a_thread_desc_),
722 decltype(b_block_desc_n0_n1_n2_k),
723 decltype(b_thread_desc_),
730 AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
731 BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
732 using Base::c_thread_desc_;
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition: ck.hpp:209
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:299
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:58
ck::BlockwiseGemmXdlops_pipeline_v1< BlockGemmPipelineScheduler::Interwave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v1.hpp:457
ck::BlockwiseGemmXdlops_pipeline_v1< BlockGemmPipelineScheduler::Interwave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v1.hpp:452
ck::BlockwiseGemmXdlops_pipeline_v1< BlockGemmPipelineScheduler::Interwave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v1.hpp:478
ck::BlockwiseGemmXdlops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v1.hpp:175
ck::BlockwiseGemmXdlops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v1.hpp:154
ck::BlockwiseGemmXdlops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v1.hpp:149
Definition: blockwise_gemm_pipeline_xdlops_v1.hpp:37
Definition: sequence.hpp:43
ck::ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataTypeBuf, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 >
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10