40 MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
42 NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
45 MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
47 NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
50 WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
52 WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
55 MPerBlock * NPerBlock * KPerBlock / (BlockSize /
WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
59 printf(
" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
69 printf(
" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
70 "%d, %d\n C MFMA inst: %d\n",
87 typename AMmaTileDesc,
88 typename BMmaTileDesc,
97 bool TransposeC =
false,
99 KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
101 KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
113 static_assert(
MWaves > 0);
114 static_assert(
NWaves > 0);
145 "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
165 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));
172 const auto waveId_m = wave_idx[
I0];
174 const auto xdlops_a_idx =
xdlops_gemm.CalculateAThreadOriginDataIndex();
176 return make_tuple(0, waveId_m, xdlops_a_idx[
I1], KPack * xdlops_a_idx[
I0]);
183 const auto waveId_n = wave_idx[
I1];
185 const auto xdlops_b_idx =
xdlops_gemm.CalculateBThreadOriginDataIndex();
187 return make_tuple(0, waveId_n, xdlops_b_idx[
I1], KPack * xdlops_b_idx[
I0]);
190 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
191 __device__
static auto
196 const auto waveId_m = wave_idx[
I0];
197 const auto waveId_n = wave_idx[
I1];
199 const auto blk_idx =
xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
211 const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
213 const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
219 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
220 __device__
static auto
225 const auto waveId_m = wave_idx[
I0];
226 const auto waveId_n = wave_idx[
I1];
228 const auto blk_idx =
xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
231 m0, n0, waveId_m, waveId_n, blk_idx[
I0], blk_idx[
I1], blk_idx[
I2], blk_idx[
I3]);
241 #if defined(__HIP_DEVICE_COMPILE__)
242 static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
243 "wrong! Desc should be known at compile-time");
246 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
248 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
257 constexpr
auto c_m0_m1_m2_n_tblk_lens =
xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
259 constexpr
auto M0 = c_m0_m1_m2_n_tblk_lens[
I0];
260 constexpr
auto M1 = c_m0_m1_m2_n_tblk_lens[
I1];
261 constexpr
auto M2 = c_m0_m1_m2_n_tblk_lens[
I2];
262 constexpr
auto N = c_m0_m1_m2_n_tblk_lens[
I3];
271 constexpr
auto c_m0_m1_m2_n_tblk_lens =
xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
273 constexpr
auto M0 = c_m0_m1_m2_n_tblk_lens[
I0];
274 constexpr
auto M1 = c_m0_m1_m2_n_tblk_lens[
I1];
275 constexpr
auto M2 = c_m0_m1_m2_n_tblk_lens[
I2];
276 constexpr
auto N = c_m0_m1_m2_n_tblk_lens[
I3];
284 constexpr
auto c_m0_m1_m2_n_tblk_lens =
xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
286 constexpr
auto M0 = c_m0_m1_m2_n_tblk_lens[
I0];
287 constexpr
auto M1 = c_m0_m1_m2_n_tblk_lens[
I1];
288 constexpr
auto M2 = c_m0_m1_m2_n_tblk_lens[
I2];
289 constexpr
auto N = c_m0_m1_m2_n_tblk_lens[
I3];
298 constexpr
auto c_block_desc_m0_n0_m1_n1_m2_n2 =
306 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
312 constexpr
auto c_block_desc_m0_n0_m1_n1_m2_n2 =
320 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
325 constexpr
auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
334 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
335 c_block_desc_g_m0_n0_m1_n1_m2_n2);
338 template <
typename CGr
idDesc_M_N>
339 __host__ __device__
static constexpr
auto
342 const auto M = c_grid_desc_m_n.GetLength(
I0);
343 const auto N = c_grid_desc_m_n.GetLength(
I1);
352 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
355 template <
typename CGr
idDesc_G_M_N>
356 __host__ __device__
static constexpr
auto
359 const auto G = c_grid_desc_g_m_n.GetLength(
I0);
360 const auto M = c_grid_desc_g_m_n.GetLength(
I1);
361 const auto N = c_grid_desc_g_m_n.GetLength(
I2);
371 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
372 c_grid_desc_g_m0_n0_m1_n1_m2_n2);
378 constexpr
auto num_ds_read_inst =
380 constexpr
auto num_ds_write_inst =
383 constexpr
auto num_buffer_load_inst =
388 constexpr
auto num_issue = num_buffer_load_inst;
392 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
393 __builtin_amdgcn_sched_group_barrier(
394 0x100, num_ds_read_inst / num_buffer_load_inst, 0);
395 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
396 __builtin_amdgcn_sched_group_barrier(
397 0x200, num_ds_write_inst / num_buffer_load_inst, 0);
398 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
399 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
400 __builtin_amdgcn_sched_group_barrier(
401 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0);
405 template <index_t stage>
411 __device__ constexpr
auto TailScheduler<1>()
414 constexpr
auto num_ds_read_inst =
416 constexpr
auto num_ds_write_inst =
421 constexpr
auto num_issue = num_ds_write_inst;
425 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
426 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
427 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
428 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
429 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
430 __builtin_amdgcn_sched_group_barrier(
431 0x100, num_ds_read_inst / num_ds_write_inst - 1, 0);
432 __builtin_amdgcn_sched_group_barrier(
433 0x008, num_mfma_inst / num_ds_write_inst - 3, 0);
438 __device__ constexpr
auto TailScheduler<2>()
441 constexpr
auto num_ds_read_inst =
445 constexpr
auto num_issue = num_ds_read_inst;
449 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
450 __builtin_amdgcn_sched_group_barrier(
451 0x008, num_mfma_inst / num_ds_read_inst, 0);
458 template <
bool HasMainLoop,
462 typename ABlockTransfer,
463 typename AGridBuffer,
464 typename ABlockBuffer,
465 typename ABlockTransferStep,
468 typename BBlockTransfer,
469 typename BGridBuffer,
470 typename BBlockBuffer,
471 typename BBlockTransferStep,
472 typename CThreadBuffer>
473 __device__
void Run(
const AGridDesc& a_grid_desc,
474 const ABlockDesc& a_block_desc,
475 ABlockTransfer& a_blockwise_copy,
476 const AGridBuffer& a_grid_buf,
477 ABlockBuffer& a_block_buf,
478 const ABlockTransferStep& a_block_copy_step,
479 const BGridDesc& b_grid_desc,
480 const BBlockDesc& b_block_desc,
481 BBlockTransfer& b_blockwise_copy,
482 const BGridBuffer& b_grid_buf,
483 BBlockBuffer& b_block_buf,
484 const BBlockTransferStep& b_block_copy_step,
485 CThreadBuffer& c_thread_buf,
488 __builtin_amdgcn_sched_barrier(0);
489 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
491 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
504 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
505 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
507 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
508 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
510 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I0));
511 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I0));
535 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
536 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
538 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
539 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
541 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I1));
542 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I1));
545 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
546 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
548 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
549 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
552 c_thread_buf.Clear();
555 if constexpr(HasMainLoop)
573 a_block_buf.At(PongP1{}),
576 a_thread_bufs(PongP1{}));
580 b_block_buf.At(PongP1{}),
583 b_thread_bufs(PongP1{}));
588 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
589 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
591 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
592 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
594 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
595 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
597 static_for<0, KRepeat, 1>{}([&](
auto k0) {
598 static_for<0, MRepeat, 1>{}([&](
auto m0) {
599 static_for<0, NRepeat, 1>{}([&](
auto n0) {
600 vector_type<FloatAB, KPack> a_thread_vec;
601 vector_type<FloatAB, KPack> b_thread_vec;
603 static_for<0, KPack, 1>{}([&](
auto ik) {
604 a_thread_vec.template AsType<FloatAB>()(ik) =
605 a_thread_bufs[PingP1{}][
Number<a_thread_desc_.CalculateOffset(
607 b_thread_vec.template AsType<FloatAB>()(ik) =
608 b_thread_bufs[PingP1{}][
Number<b_thread_desc_.CalculateOffset(
612 using mfma_input_type =
613 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
616 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
619 a_thread_vec.template AsType<mfma_input_type>(),
620 b_thread_vec.template AsType<mfma_input_type>(),
621 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
627 __builtin_amdgcn_sched_barrier(0);
630 using PingP2 = Number<1>;
631 using PongP2 = Number<0>;
637 static_for<0, KRepeat, 1>{}([&](
auto k) {
638 static_for<0, MRepeat, 1>{}([&](
auto m0) {
639 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
640 make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
641 a_block_buf.At(PongP2{}),
644 a_thread_bufs(PongP2{}));
645 static_for<0, NRepeat, 1>{}([&](
auto n0) {
646 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
647 make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
648 b_block_buf.At(PongP2{}),
651 b_thread_bufs(PongP2{}));
656 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{}));
657 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{}));
659 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
660 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
662 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
663 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
665 static_for<0, KRepeat, 1>{}([&](
auto k0) {
666 static_for<0, MRepeat, 1>{}([&](
auto m0) {
667 static_for<0, NRepeat, 1>{}([&](
auto n0) {
668 vector_type<FloatAB, KPack> a_thread_vec;
669 vector_type<FloatAB, KPack> b_thread_vec;
671 static_for<0, KPack, 1>{}([&](
auto ik) {
672 a_thread_vec.template AsType<FloatAB>()(ik) =
673 a_thread_bufs[PingP2{}][
Number<a_thread_desc_.CalculateOffset(
675 b_thread_vec.template AsType<FloatAB>()(ik) =
676 b_thread_bufs[PingP2{}][
Number<b_thread_desc_.CalculateOffset(
680 using mfma_input_type =
681 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
684 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
687 a_thread_vec.template AsType<mfma_input_type>(),
688 b_thread_vec.template AsType<mfma_input_type>(),
689 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
695 __builtin_amdgcn_sched_barrier(0);
698 }
while(i < (num_loop - 3));
702 if constexpr(TailNum == 3)
704 using PingP1 = Number<0>;
705 using PongP1 = Number<1>;
711 static_for<0, KRepeat, 1>{}([&](
auto k) {
712 static_for<0, MRepeat, 1>{}([&](
auto m0) {
713 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
714 make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
715 a_block_buf.At(PongP1{}),
718 a_thread_bufs(PongP1{}));
719 static_for<0, NRepeat, 1>{}([&](
auto n0) {
720 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
721 make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
722 b_block_buf.At(PongP1{}),
725 b_thread_bufs(PongP1{}));
730 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
731 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
733 static_for<0, KRepeat, 1>{}([&](
auto k0) {
734 static_for<0, MRepeat, 1>{}([&](
auto m0) {
735 static_for<0, NRepeat, 1>{}([&](
auto n0) {
736 vector_type<FloatAB, KPack> a_thread_vec;
737 vector_type<FloatAB, KPack> b_thread_vec;
739 static_for<0, KPack, 1>{}([&](
auto ik) {
740 a_thread_vec.template AsType<FloatAB>()(ik) =
741 a_thread_bufs[PingP1{}][
Number<a_thread_desc_.CalculateOffset(
743 b_thread_vec.template AsType<FloatAB>()(ik) =
744 b_thread_bufs[PingP1{}][
Number<b_thread_desc_.CalculateOffset(
748 using mfma_input_type =
749 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
752 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
754 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
755 b_thread_vec.template AsType<mfma_input_type>(),
756 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
762 __builtin_amdgcn_sched_barrier(0);
765 using PingP2 = Number<1>;
766 using PongP2 = Number<0>;
772 static_for<0, KRepeat, 1>{}([&](
auto k) {
773 static_for<0, MRepeat, 1>{}([&](
auto m0) {
774 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
775 make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
776 a_block_buf.At(PongP2{}),
779 a_thread_bufs(PongP2{}));
780 static_for<0, NRepeat, 1>{}([&](
auto n0) {
781 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
782 make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
783 b_block_buf.At(PongP2{}),
786 b_thread_bufs(PongP2{}));
791 static_for<0, KRepeat, 1>{}([&](
auto k0) {
792 static_for<0, MRepeat, 1>{}([&](
auto m0) {
793 static_for<0, NRepeat, 1>{}([&](
auto n0) {
794 vector_type<FloatAB, KPack> a_thread_vec;
795 vector_type<FloatAB, KPack> b_thread_vec;
797 static_for<0, KPack, 1>{}([&](
auto ik) {
798 a_thread_vec.template AsType<FloatAB>()(ik) =
799 a_thread_bufs[PingP2{}][
Number<a_thread_desc_.CalculateOffset(
801 b_thread_vec.template AsType<FloatAB>()(ik) =
802 b_thread_bufs[PingP2{}][
Number<b_thread_desc_.CalculateOffset(
806 using mfma_input_type =
807 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
810 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
812 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
813 b_thread_vec.template AsType<mfma_input_type>(),
814 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
820 __builtin_amdgcn_sched_barrier(0);
822 static_for<0, KRepeat, 1>{}([&](
auto k) {
823 static_for<0, MRepeat, 1>{}([&](
auto m0) {
824 static_for<0, NRepeat, 1>{}([&](
auto n0) {
825 vector_type<FloatAB, KPack> a_thread_vec;
826 vector_type<FloatAB, KPack> b_thread_vec;
828 static_for<0, KPack, 1>{}([&](
auto ik) {
829 a_thread_vec.template AsType<FloatAB>()(ik) =
830 a_thread_bufs[PongP2{}][
Number<a_thread_desc_.CalculateOffset(
832 b_thread_vec.template AsType<FloatAB>()(ik) =
833 b_thread_bufs[PongP2{}][
Number<b_thread_desc_.CalculateOffset(
837 using mfma_input_type =
838 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
841 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
843 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
844 b_thread_vec.template AsType<mfma_input_type>(),
845 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
851 __builtin_amdgcn_sched_group_barrier(0x008, 64, 0);
852 __builtin_amdgcn_sched_barrier(0);
854 else if constexpr(TailNum == 2)
856 using PingP1 = Number<0>;
857 using PongP1 = Number<1>;
863 static_for<0, KRepeat, 1>{}([&](
auto k) {
864 static_for<0, MRepeat, 1>{}([&](
auto m0) {
865 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
866 make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
867 a_block_buf.At(PongP1{}),
870 a_thread_bufs(PongP1{}));
871 static_for<0, NRepeat, 1>{}([&](
auto n0) {
872 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
873 make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
874 b_block_buf.At(PongP1{}),
877 b_thread_bufs(PongP1{}));
882 static_for<0, KRepeat, 1>{}([&](
auto k0) {
883 static_for<0, MRepeat, 1>{}([&](
auto m0) {
884 static_for<0, NRepeat, 1>{}([&](
auto n0) {
885 vector_type<FloatAB, KPack> a_thread_vec;
886 vector_type<FloatAB, KPack> b_thread_vec;
888 static_for<0, KPack, 1>{}([&](
auto ik) {
889 a_thread_vec.template AsType<FloatAB>()(ik) =
890 a_thread_bufs[PingP1{}][
Number<a_thread_desc_.CalculateOffset(
892 b_thread_vec.template AsType<FloatAB>()(ik) =
893 b_thread_bufs[PingP1{}][
Number<b_thread_desc_.CalculateOffset(
897 using mfma_input_type =
898 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
901 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
903 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
904 b_thread_vec.template AsType<mfma_input_type>(),
905 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
911 __builtin_amdgcn_sched_barrier(0);
914 using PingP2 = Number<1>;
919 static_for<0, KRepeat, 1>{}([&](
auto k0) {
920 static_for<0, MRepeat, 1>{}([&](
auto m0) {
921 static_for<0, NRepeat, 1>{}([&](
auto n0) {
922 vector_type<FloatAB, KPack> a_thread_vec;
923 vector_type<FloatAB, KPack> b_thread_vec;
925 static_for<0, KPack, 1>{}([&](
auto ik) {
926 a_thread_vec.template AsType<FloatAB>()(ik) =
927 a_thread_bufs[PingP2{}][
Number<a_thread_desc_.CalculateOffset(
929 b_thread_vec.template AsType<FloatAB>()(ik) =
930 b_thread_bufs[PingP2{}][
Number<b_thread_desc_.CalculateOffset(
934 using mfma_input_type =
935 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
938 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
940 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
941 b_thread_vec.template AsType<mfma_input_type>(),
942 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
948 __builtin_amdgcn_sched_group_barrier(0x008, 64, 0);
949 __builtin_amdgcn_sched_barrier(0);
960 Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
966 Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
974 decltype(a_block_desc_m0_m1_m2_k),
975 decltype(a_thread_desc_),
984 decltype(b_block_desc_n0_n1_n2_k),
985 decltype(b_thread_desc_),
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:190
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
integral_constant< index_t, N > Number
Definition: number.hpp:12
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
static constexpr index_t B_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:46
static constexpr index_t A_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:49
static constexpr index_t B_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:51
static constexpr index_t A_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:44
static constexpr index_t C_MFMA_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:54
static constexpr index_t A_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:39
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_xdlops.hpp:37
static constexpr index_t B_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:41
static constexpr auto Print()
Definition: blockwise_gemm_pipeline_xdlops.hpp:57
static constexpr index_t WaveNumN
Definition: blockwise_gemm_pipeline_xdlops.hpp:36
static constexpr index_t WaveNumM
Definition: blockwise_gemm_pipeline_xdlops.hpp:35
Definition: blockwise_gemm_pipeline_xdlops.hpp:103
static constexpr auto I1
Definition: blockwise_gemm_pipeline_xdlops.hpp:105
__host__ static constexpr __device__ auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition: blockwise_gemm_pipeline_xdlops.hpp:357
static constexpr index_t MWaves
Definition: blockwise_gemm_pipeline_xdlops.hpp:111
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_pipeline_xdlops.hpp:255
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: blockwise_gemm_pipeline_xdlops.hpp:340
static constexpr index_t A_K1
Definition: blockwise_gemm_pipeline_xdlops.hpp:119
static constexpr index_t A_K0
Definition: blockwise_gemm_pipeline_xdlops.hpp:117
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:963
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops.hpp:310
static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops.hpp:375
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_xdlops.hpp:115
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:993
decltype(CalculateAThreadOriginDataIndex()) Tuple4
Definition: blockwise_gemm_pipeline_xdlops.hpp:234
static constexpr auto I0
Definition: blockwise_gemm_pipeline_xdlops.hpp:104
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:455
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_xdlops.hpp:179
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:992
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops.hpp:269
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_pipeline_xdlops.hpp:109
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:456
static constexpr index_t KRepeat
Definition: blockwise_gemm_pipeline_xdlops.hpp:126
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_pipeline_xdlops.hpp:154
static constexpr auto I3
Definition: blockwise_gemm_pipeline_xdlops.hpp:107
static constexpr index_t B_K1
Definition: blockwise_gemm_pipeline_xdlops.hpp:120
static constexpr auto I2
Definition: blockwise_gemm_pipeline_xdlops.hpp:106
static constexpr index_t B_K0
Definition: blockwise_gemm_pipeline_xdlops.hpp:118
__host__ static constexpr __device__ auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops.hpp:323
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_xdlops.hpp:168
static constexpr index_t KPerThread
Definition: blockwise_gemm_pipeline_xdlops.hpp:125
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_pipeline_xdlops.hpp:156
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:957
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_pipeline_xdlops.hpp:296
static constexpr index_t NWaves
Definition: blockwise_gemm_pipeline_xdlops.hpp:112
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition: blockwise_gemm_pipeline_xdlops.hpp:145
__host__ static constexpr __device__ auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops.hpp:282
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_pipeline_xdlops.hpp:122
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_pipeline_xdlops.hpp:221
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops.hpp:473
__host__ __device__ BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Definition: blockwise_gemm_pipeline_xdlops.hpp:237
static constexpr __device__ auto TailScheduler()
Definition: blockwise_gemm_pipeline_xdlops.hpp:406
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_pipeline_xdlops.hpp:192
Definition: sequence.hpp:43
Definition: static_buffer.hpp:75
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
static constexpr __device__ index_t GetNumOfThread()
Definition: thread_group.hpp:15
ck::ThreadwiseTensorSliceTransfer_v4< FloatAB, FloatAB, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 >
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1293
Definition: xdlops_gemm.hpp:1711
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33