template <typename GridwiseGemm>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
__global__ void kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
                                           const typename GridwiseGemm::FloatAB* p_b_grid,
                                           typename GridwiseGemm::FloatC* p_c_grid,
                                           void* p_workspace,
                                           index_t M, index_t N, index_t K,
                                           index_t StrideA, index_t StrideB, index_t StrideC,
                                           typename GridwiseGemm::Block2CTileMap block_mapping)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
        constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();

        __shared__ uint8_t p_shared[shared_size];

        GridwiseGemm::Run(p_a_grid,
                          p_b_grid,
                          p_c_grid,
                          p_workspace,
                          M, N, K,
                          StrideA, StrideB, StrideC,
                          block_mapping,
                          static_cast<void*>(p_shared));
    }
#endif
}
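// ---------------------------------------------------------------------------
// Illustrative host-side launch (a sketch under assumptions, not part of this
// header): `GridwiseGemmT`, `grid_size`, `BlockSize`, `arg` and `stream` are
// hypothetical stand-ins; only the kernel's parameter list above is given.
// LDS is allocated statically inside the wrapper, so dynamic shared memory
// is 0.
//
//   hipLaunchKernelGGL(kernel_gemm_xdlops_streamk<GridwiseGemmT>,
//                      dim3(grid_size), dim3(BlockSize), 0, stream,
//                      arg.p_a_grid, arg.p_b_grid, arg.p_c_grid, p_workspace,
//                      arg.M, arg.N, arg.K,
//                      arg.StrideA, arg.StrideB, arg.StrideC,
//                      arg.block_mapping);
// ---------------------------------------------------------------------------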
template <// ...
          typename Block2CTileMap_,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
    // ...
    static constexpr auto M01 = 1;
    static constexpr auto N01 = 1;
    void Print() const
    {
        std::cout << "arg {"
                  << "M:" << M << ", "
                  << "N:" << N << ", "
                  << "K:" << K << ", "
                  // ...
    }
    __host__ __device__ static auto
    MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
    {
        const auto a_grid_desc_m_k = [&]() {
            // ...
    __host__ __device__ static auto
    MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
    {
        const auto b_grid_desc_k_n = [&]() {
            // ...
    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_m_n = [&]() {
            // ...
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;
        // ...
        constexpr auto a_block_space_size_aligned = /* ... */;
        constexpr auto b_block_space_size_aligned = /* ... */;
        constexpr auto c_block_size                = /* ... */;

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                         /* ... */);
    }
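    // -----------------------------------------------------------------------
    // A minimal compile-time sketch of the LDS budget computed above, with
    // hypothetical tile numbers (K0PerBlock=4, MPerBlock=NPerBlock=128, K1=8,
    // 2-byte FloatAB): each operand slab is rounded up to the K1 alignment,
    // and the A and B slabs sit next to each other in LDS.
    static_assert([] {
        constexpr int k0 = 4, m = 128, n = 128, k1 = 8;
        constexpr int align = k1;
        auto least_multiple = [](int x, int y) { return (x + y - 1) / y * y; };
        int a_size = least_multiple(k0 * m * k1, align); // 4096 elements
        int b_size = least_multiple(k0 * n * k1, align); // 4096 elements
        return (a_size + b_size) * 2 == 16384;           // bytes of LDS for A+B
    }());
    // -----------------------------------------------------------------------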
        // vector-access divisibility checks on the problem sizes
        if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            // ...
        if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            // ...
        if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            // ...
        if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            // ...
        if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
            // ...
        if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
            // ...
        const bool has_main_k0_block_loop = K0 > K0PerBlock;

        return has_main_k0_block_loop;
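    // -----------------------------------------------------------------------
    // Sketch of the predicate above with hypothetical sizes: K0 counts
    // K1-wide slices of K, and a main loop is only needed when more than one
    // K0PerBlock-sized chunk must be streamed through LDS.
    static_assert([] {
        constexpr int K1_ = 8, K0PerBlock_ = 4;
        constexpr int K_  = 256;
        constexpr int K0_ = K_ / K1_;   // 32 slices
        return K0_ > K0PerBlock_;       // true -> main k0 block loop runs
    }());
    // -----------------------------------------------------------------------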
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;
        // ...
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCBlockClusterAdaptor(const CGridDesc& c_m_n_grid_desc, index_t, index_t, index_t KBatch)
    {
        return /* ... */(c_m_n_grid_desc, 8, KBatch);
    }
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
        constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 1 : NPerBlock / (NRepeat * NPerXdl);
        // ...
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
        constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 1 : NPerBlock / (NRepeat * NPerXdl);

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat / CShuffleMRepeatPerShuffle>{},
                       /* ... */
                       Number<NRepeat / CShuffleNRepeatPerShuffle>{},
                       /* ... */));
        // ...
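    // -----------------------------------------------------------------------
    // Worked example of the wave decomposition above, with hypothetical tile
    // sizes: a 256x128 block tile covered by 32x32 XDL ops at MRepeat=4,
    // NRepeat=2 leaves 2 waves along M and 2 waves along N.
    static_assert([] {
        constexpr int MPerBlock_ = 256, NPerBlock_ = 128;
        constexpr int MPerXdl_ = 32, NPerXdl_ = 32, MRepeat_ = 4, NRepeat_ = 2;
        constexpr int MWave_ = MPerBlock_ / (MRepeat_ * MPerXdl_); // 2
        constexpr int NWave_ = NPerBlock_ / (NRepeat_ * NPerXdl_); // 2
        return MWave_ == 2 && NWave_ == 2;
    }());
    // -----------------------------------------------------------------------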
    __host__ __device__ static constexpr auto GetClusterLengthReduction()
    {
        constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
        constexpr auto NPerBlockReduction =
            NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
        constexpr auto MPerBlockReduction =
            (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
        // ...
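    // -----------------------------------------------------------------------
    // Worked example of the reduction cluster shape above, with hypothetical
    // values: NPerBlock=128 (already a power of two), 4-wide vector accesses
    // and BlockSize=256 give a 32-thread row along N and 8 rows along M.
    static_assert([] {
        constexpr int NPerBlock_ = 128, VecN_ = 4, BlockSize_ = 256;
        constexpr int NPerBlockPow2_      = 128;                    // next_power_of_two(128)
        constexpr int NPerBlockReduction_ = NPerBlockPow2_ / VecN_; // 32 threads in N
        constexpr int MPerBlockReduction_ =
            (BlockSize_ + NPerBlockReduction_ - 1) / NPerBlockReduction_; // 8 rows in M
        return NPerBlockReduction_ == 32 && MPerBlockReduction_ == 8;
    }());
    // -----------------------------------------------------------------------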
        const auto c_partial_acc_block_m_n = [&]() {
            // ...
        }();

        return c_partial_acc_block_m_n;
    static __device__ void Run(const FloatAB* p_a_grid,
                               const FloatAB* p_b_grid,
                               FloatC* p_c_grid,
                               void* p_workspace,
                               index_t M, index_t N, index_t K,
                               index_t StrideA, index_t StrideB, index_t StrideC,
                               Block2CTileMap block_mapping,
                               void* __restrict__ p_shared_block)
    {
        // ...
        uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
        uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = /* ... */;

        const AElementwiseOperation a_element_op = AElementwiseOperation{};
        const BElementwiseOperation b_element_op = BElementwiseOperation{};
        const CElementwiseOperation c_element_op = CElementwiseOperation{};

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        constexpr auto max_lds_align = K1;
        auto blockwise_gemm =
            /* ... blockwise XDLOPS GEMM instantiated over, among others,
               decltype(a_block_desc_k0_m_k1) and decltype(b_block_desc_k0_n_k1) ... */;

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        constexpr auto a_block_space_size = /* ... */;

        FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
        // classify this workgroup's role in the stream-k decomposition
        uint32_t block_idx = block_mapping.get_block_idx();
        bool is_sk_block   = block_idx < block_mapping.sk_num_blocks;
        bool is_dp_block   = block_idx >= block_mapping.dp_start_block_idx &&
                             block_idx < block_mapping.reduction_start_block_idx;
        bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
        bool is_padding_block   = block_idx >= block_mapping.sk_num_blocks &&
                                  block_idx < block_mapping.dp_start_block_idx;
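        // -------------------------------------------------------------------
        // Sketch of the block taxonomy above, under the assumed mapping
        // layout: [0, sk_num_blocks) are stream-k blocks, then optional
        // padding blocks up to dp_start_block_idx, data-parallel blocks up to
        // reduction_start_block_idx, and reduction blocks after that. The
        // boundary values below are hypothetical.
        static_assert([] {
            constexpr unsigned sk_num_blocks = 10, dp_start = 12, red_start = 100;
            auto role = [&](unsigned bid) {
                return bid < sk_num_blocks ? 'S'   // stream-k
                       : bid < dp_start    ? 'P'   // padding
                       : bid < red_start   ? 'D'   // data-parallel
                                           : 'R';  // reduction
            };
            return role(3) == 'S' && role(11) == 'P' && role(50) == 'D' &&
                   role(100) == 'R';
        }());
        // -------------------------------------------------------------------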
        block_mapping.get_block_itr(block_idx, iter_start, iter_end);
        uint32_t total_iter_length = iter_end - iter_start;
        // the tile-flag area begins right after the accumulation workspace
        /* ... = */ reinterpret_cast<uint32_t*>(
            reinterpret_cast<char*>(p_workspace) +
            block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
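        // -------------------------------------------------------------------
        // Assumed workspace layout implied by the pointer arithmetic above:
        // the partial-accumulation area (FloatAcc) comes first, and the
        // uint32_t tile-flag area used by the workgroup barrier starts right
        // after it, which is why the flag pointer offsets p_workspace by
        // get_workspace_size_for_acc(sizeof(FloatAcc)).
        //
        //   | partial-acc buffers (FloatAcc) ... | tile flags (uint32_t) ... |
        //   ^ p_workspace                        ^ p_workspace + acc_bytes
        // -------------------------------------------------------------------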
        if(is_reduction_block)
        {
            // ...
            const auto reduce_thread_cluster_idx = /* ... */;

            const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
            const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
            constexpr auto MReduceIters = /* ... */;
            constexpr auto NReduceIters = /* ... derived from
                cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL ... */;

            constexpr auto partial_acc_load_step_n = make_multi_index(
                0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
            constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
                /* ... */,
                -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                    CBlockTransferScalarPerVector_NWaveNPerXDL);
            constexpr auto partial_acc_load_step_m = make_multi_index(/* ... */);

            constexpr auto partial_acc_store_step_n = make_multi_index(
                /* ... */,
                cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
            constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
                /* ... */,
                -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                    CBlockTransferScalarPerVector_NWaveNPerXDL);
            constexpr auto partial_acc_store_step_m = make_multi_index(/* ... */);

            // thread-slice descriptors for the partial-accumulation load and
            // store, each CBlockTransferScalarPerVector_NWaveNPerXDL wide in N
            // ...
            auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
            auto spatial_idx   = block_mapping.tile_to_spatial(reduction_idx, m, n);

            // ...
            const auto tile_acc_offset_start =
                block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
            const auto tile_acc_offset_end =
                block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
            auto acc_load = ThreadwiseTensorSliceTransfer_v2<
                // ...
                decltype(c_partial_acc_block_m_n),
                decltype(acc_thread_buf_load_desc),
                // ...
                CBlockTransferScalarPerVector_NWaveNPerXDL,
                // ...
                >{c_partial_acc_block_m_n,
                  make_multi_index(/* ... */,
                                   thread_n_cluster_id *
                                       CBlockTransferScalarPerVector_NWaveNPerXDL)};

            auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
                // ...
                decltype(acc_thread_buf_store_desc),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                CElementwiseOperation,
                // ...
                CBlockTransferScalarPerVector_NWaveNPerXDL,
                // ...
                >{c_grid_desc_mblock_mperblock_nblock_nperblock,
                  make_multi_index(/* ... */,
                                   __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
                                   /* ... */,
                                   thread_n_cluster_id *
                                       CBlockTransferScalarPerVector_NWaveNPerXDL),
                  CElementwiseOperation{}};
            // wait until every stream-k producer of this tile has signaled
            wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
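            // ---------------------------------------------------------------
            // Sketch of the assumed barrier protocol: each stream-k block
            // that produced a partial accumulator for this tile calls
            // wg_barrier.inc(tile_idx) after its store (see the epilogue
            // further below), so the reduction block can spin in wait_eq
            // until the flag equals the number of partials it must consume:
            //
            //   producers: store partial -> wg_barrier.inc(tile_idx);
            //   consumer:  wg_barrier.wait_eq(reduction_idx,
            //                  tile_acc_offset_end - tile_acc_offset_start);
            // ---------------------------------------------------------------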
            if(threadIdx.x == 0)
            {
                printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n",
                       static_cast<int>(blockIdx.x),
                       reduction_idx,
                       __builtin_amdgcn_readfirstlane(tile_acc_offset_start),
                       __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
                       __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
                       __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
            }
            using Accumulation = ck::detail::/* ... */;

            for(int i_m = 0; i_m < MReduceIters; i_m++)
            {
                // ...
                for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
                {
                    auto c_partial_acc_buf =
                        make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                            reinterpret_cast<FloatAcc*>(p_workspace) +
                                i * c_partial_acc_block_m_n.GetElementSpaceSize(),
                            c_partial_acc_block_m_n.GetElementSpaceSize());
                    acc_load.Run(c_partial_acc_block_m_n,
                                 /* ... */
                                 acc_thread_buf_load_desc,
                                 /* ... */);

                    // accumulate this partial into the running thread buffer
                    constexpr auto offset =
                        acc_thread_buf_load_desc.CalculateOffset(/* ... */);
                    // ...
                }

                // only threads whose N column is in range store the result
                if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
                   /* ... */)
                {
                    acc_store.Run(acc_thread_buf_store_desc,
                                  /* ... */
                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                  /* ... */);
                }

                if constexpr(NReduceIters != 1)
                {
                    if constexpr(i_n_reduce != (NReduceIters - 1))
                    {
                        acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                    partial_acc_load_step_n);
                        acc_store.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            partial_acc_store_step_n);
                    }
                    else
                    {
                        acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                    partial_acc_load_step_n_reverse);
                        acc_store.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            partial_acc_store_step_n_reverse);
                    }
                }
                // ...
                acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                            partial_acc_load_step_m);
                acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
                                             partial_acc_store_step_m);
            }
        // ... (end of reduction-block path)
        uint32_t block_acc_offset =
            (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
            NPerBlock;
        uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
            block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
        // ...
        block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
        iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
        auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n);

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock);

        const index_t k0_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
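        // -------------------------------------------------------------------
        // Worked example of the iteration bookkeeping above, assuming
        // get_tile_idx_with_offset is a divide/modulo by the per-tile
        // iteration count (all numbers hypothetical): a tile costs 8
        // K0-iterations and this block owns the global range [20, 24), so its
        // last iteration (23) lands in tile 2 at offset 7, and rewinding by
        // current_iter_length-1 gives the starting K0 offset inside the tile.
        static_assert([] {
            constexpr unsigned iters_per_tile = 8;
            constexpr unsigned iter_start = 20, iter_end = 24;
            constexpr unsigned current_iter_length = iter_end - iter_start;   // 4
            constexpr unsigned tile_idx    = (iter_end - 1) / iters_per_tile; // tile 2
            constexpr unsigned last_offset = (iter_end - 1) % iters_per_tile; // 7
            constexpr unsigned iter_offset = last_offset - current_iter_length + 1;
            return tile_idx == 2 && iter_offset == 4; // K0 chunks 4..7 of tile 2
        }());
        // -------------------------------------------------------------------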
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            // ...
            AElementwiseOperation,
            // ...
            ABlockTransferThreadClusterLengths_K0_M_K1,
            ABlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(a_k0_m_k1_grid_desc),
            decltype(a_block_desc_k0_m_k1),
            ABlockTransferSrcAccessOrder,
            // ...
            ABlockTransferSrcVectorDim,
            // ...
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_K1,
            // ...
            AThreadTransferSrcResetCoordinateAfterRun,
            // ...
            >(a_k0_m_k1_grid_desc,
              /* ... */
              a_block_desc_k0_m_k1,
              /* ... */);
        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            // ...
            BElementwiseOperation,
            // ...
            BBlockTransferThreadClusterLengths_K0_N_K1,
            BBlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(b_k0_n_k1_grid_desc),
            decltype(b_block_desc_k0_n_k1),
            BBlockTransferSrcAccessOrder,
            // ...
            BBlockTransferSrcVectorDim,
            // ...
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_K1,
            // ...
            BThreadTransferSrcResetCoordinateAfterRun,
            // ...
            >(b_k0_n_k1_grid_desc,
              /* ... */
              b_block_desc_k0_n_k1,
              /* ... */);
        const index_t num_k_block_main_loop = current_iter_length;

        gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
                                   a_block_desc_k0_m_k1,
                                   // ...
                                   a_block_slice_copy_step,
                                   // ...
                                   b_block_desc_k0_n_k1,
                                   // ...
                                   b_block_slice_copy_step,
                                   // ...
                                   num_k_block_main_loop);
        // epilogue: shuffle C from registers through LDS to global memory
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);

        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
        constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
        constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
        constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
        constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
        constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
        constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
        constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
        constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
            GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle();

        constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
            GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle();

        auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            /* ... */,
            c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());

        auto c_partial_acc_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                reinterpret_cast<FloatAcc*>(p_workspace) + block_acc_offset,
                c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
                    .GetElementSpaceSize());
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
            c_block_desc_mblock_mpershuffle_nblock_npershuffle,
            /* ... */);
        const auto c_thread_mtx_on_block =
            blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));
        auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
            // ...
            decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            // ...
            CShuffleNRepeatPerShuffle,
            // ...
            true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                  make_multi_index(/* ... */,
                                   m_thread_data_on_block_idx[I1],
                                   n_thread_data_on_block_idx[I1],
                                   m_thread_data_on_block_idx[I2],
                                   m_thread_data_on_block_idx[I3],
                                   m_thread_data_on_block_idx[I4],
                                   n_thread_data_on_block_idx[I2]),
                  /* ... */};
        auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
            // ...
            CElementwiseOperation,
            // ...
            Sequence<1,
                     CShuffleMRepeatPerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNRepeatPerShuffle * NWave * NPerXdl>,
            CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            // ...
            decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
            // ...
            CBlockTransferScalarPerVector_NWaveNPerXDL,
            // ...
            >{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
              /* ... */
              c_grid_desc_mblock_mperblock_nblock_nperblock,
              make_multi_index(/* ... */,
                               __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
                               /* ... */),
              /* ... */};
        auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
            // ...
            CElementwiseOperation,
            // ...
            Sequence<1,
                     CShuffleMRepeatPerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNRepeatPerShuffle * NWave * NPerXdl>,
            CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            // ...
            decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
            decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
            // ...
            CBlockTransferScalarPerVector_NWaveNPerXDL,
            // ...
            >{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
              /* ... */
              c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
              /* ... */};
        constexpr auto mxdlperwave_forward_step  = /* ... */;
        constexpr auto nxdlperwave_forward_step  = /* ... */;
        constexpr auto nxdlperwave_backward_step = /* ... */;

        static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
            constexpr auto mxdlperwave = mxdlperwave_iter;

            // even M steps sweep N forward, odd M steps sweep it backward
            constexpr bool nxdlperwave_forward_sweep =
                (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

            constexpr index_t nxdlperwave_value =
                nxdlperwave_forward_sweep
                    ? /* ... */
                    : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
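            // ---------------------------------------------------------------
            // Worked example of the serpentine sweep above, with hypothetical
            // repeats (MRepeat = NRepeat = 4, shuffle step 2): even M steps
            // walk N forward, odd M steps walk it backward, so the output
            // window visits (0,0), (0,2), (2,2), (2,0) without jumping back.
            static_assert([] {
                constexpr int MRep = 4, NRep = 4, MSh = 2, NSh = 2;
                int expect[4][2] = {{0, 0}, {0, 2}, {2, 2}, {2, 0}};
                int k = 0;
                for(int m = 0; m < MRep; m += MSh)
                    for(int i = 0; i < NRep; i += NSh)
                    {
                        bool fwd = (m % (2 * MSh)) == 0;
                        int n    = fwd ? i : (NRep - i - NSh);
                        if(expect[k][0] != m || expect[k][1] != n)
                            return false;
                        ++k;
                    }
                return true;
            }());
            // ---------------------------------------------------------------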
            c_thread_copy_vgpr_to_lds.Run(
                c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                /* ... */
                c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                /* ... */);

            // ...
            c_block_copy_lds_to_global.SetSrcSliceOrigin(
                c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                /* ... */);

            if(is_dp_block)
            {
                c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
                                                        decltype(c_grid_buf),
                                                        /* ... */>(
                    c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                    /* ... */
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    /* ... */);
            }
            else if(is_sk_block)
            {
                if constexpr(Block2CTileMap::ReductionStrategy ==
                             StreamKReductionStrategy::Reduction)
                {
                    // store this block's partial accumulation to the workspace
                    c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
                        c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                        /* ... */);

                    c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
                        c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                        make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));

                    c_block_copy_lds_to_partial_acc
                        .template Run<decltype(c_block_buf),
                                      decltype(c_partial_acc_buf),
                                      /* ... */>(
                            c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                            /* ... */
                            c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                            /* ... */);
                }
                else if constexpr(Block2CTileMap::ReductionStrategy ==
                                  StreamKReductionStrategy::Atomic)
                {
                    c_block_copy_lds_to_global
                        .template Run<decltype(c_block_buf),
                                      decltype(c_grid_buf),
                                      /* ... */>(
                            c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                            /* ... */
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            /* ... */);
                }
            }
            // move the destination window along N (serpentine order)
            if constexpr(nxdlperwave_forward_sweep &&
                         (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
            {
                c_block_copy_lds_to_global.MoveDstSliceWindow(
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    nxdlperwave_forward_step);
            }
            else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
            {
                c_block_copy_lds_to_global.MoveDstSliceWindow(
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    nxdlperwave_backward_step);
            }

            // move the destination window along M
            if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
            {
                c_block_copy_lds_to_global.MoveDstSliceWindow(
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    mxdlperwave_forward_step);
            }
        });

        // signal the reduction workgroup once this tile's partials are out
        if constexpr(Block2CTileMap::ReductionStrategy ==
                     StreamKReductionStrategy::Reduction)
        {
            // ...
            wg_barrier.inc(tile_idx);
        }

        iter_end -= current_iter_length;
        if(iter_end <= iter_start)
            break; // leave the persistent stream-k loop
        // ...
        block_acc_offset -= MPerBlock * NPerBlock;
    template <typename Layout>
    struct LStr
    {
        static std::string Get() { return ""; }
    };

    template <>
    struct LStr<tensor_layout::gemm::RowMajor>
    {
        static std::string Get() { return "R"; }
    };

    template <>
    struct LStr<tensor_layout::gemm::ColumnMajor>
    {
        static std::string Get() { return "C"; }
    };
    static std::string GetTypeString()
    {
        auto str = std::stringstream();
        // ...
        str << "GemmXdlStreamK_"
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            // ...
            << "B" << BlockSize << "_"
            << "Vec" << ABlockTransferSrcScalarPerVector << "x"
            << BBlockTransferSrcScalarPerVector << "x"
            << CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
            // ...
            << K0PerBlock << "x"
            // ...