template <typename GridwiseGemm>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    __global__ void kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
                                               const typename GridwiseGemm::FloatAB* p_b_grid,
                                               typename GridwiseGemm::FloatC* p_c_grid,
                                               void* p_workspace,
                                               index_t M,
                                               index_t N,
                                               index_t K,
                                               index_t StrideA,
                                               index_t StrideB,
                                               index_t StrideC,
                                               typename GridwiseGemm::Block2CTileMap block_mapping)
{
#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();

    __shared__ uint8_t p_shared[shared_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_workspace,
                      M,
                      N,
                      K,
                      StrideA,
                      StrideB,
                      StrideC,
                      block_mapping,
                      static_cast<void*>(p_shared));
#endif
}
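// A minimal host-side launch sketch. Everything beyond the kernel name is an
// assumption for illustration: "karg" stands for a GridwiseGemm::Argument,
// "p_workspace" for the stream-k workspace allocation, and "stream" for a
// hipStream_t; grid sizing would come from GridwiseGemm::CalculateGridSize(karg).
//
//   const auto grid_size = GridwiseGemm::CalculateGridSize(karg);
//   hipLaunchKernelGGL(kernel_gemm_xdlops_streamk<GridwiseGemm>,
//                      dim3(grid_size), dim3(256 /* BlockSize */), 0, stream,
//                      karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_workspace,
//                      karg.M, karg.N, karg.K,
//                      karg.StrideA, karg.StrideB, karg.StrideC,
//                      karg.block_mapping);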
          typename Block2CTileMap_,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
    static constexpr auto M01 = 1;
    static constexpr auto N01 = 1;
        void Print() const
        {
            std::cout << "arg {"
                      << "M:" << M << ", "
                      << "N:" << N << ", "
                      << "K:" << K << ", "
    __host__ __device__ static auto
    MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
    {
        const auto a_grid_desc_m_k = [&]() {
    __host__ __device__ static auto
    MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
    {
        const auto b_grid_desc_k_n = [&]() {
    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_m_n = [&]() {
        constexpr auto max_lds_align = K1;

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1().GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1().GetElementSpaceSize(), max_lds_align);

        constexpr auto c_block_size =
            GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle().GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(FloatAB),
                         c_block_size * sizeof(FloatCShuffle));
    __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
    {
        // vector access of A: along K for row-major, along M for column-major
        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
                return false;
        }
        else
        {
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
                return false;
        }

        // vector access of B: along N for row-major, along K for column-major
        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
        {
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
                return false;
        }
        else
        {
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
                return false;
        }

        // vector access of C: along N for row-major, along M for column-major
        if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
        {
            if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
                return false;
        }
        else
        {
            if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
                return false;
        }
    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = K0 > K0PerBlock;

        return has_main_k0_block_loop;
    }
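    // The flag simply records whether more than one K0PerBlock-sized slice has to
    // be iterated; e.g. K0 = 32 with K0PerBlock = 8 leaves four slices, so the
    // pipelined main loop is taken.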
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCBlockClusterAdaptor(const CGridDesc& c_m_n_grid_desc, index_t, index_t, index_t KBatch)
    {
            c_m_n_grid_desc, 8, KBatch);
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat / CShuffleMRepeatPerShuffle>{},
                       Number<CShuffleMRepeatPerShuffle * MWave * MPerXDL>{},
                       Number<NRepeat / CShuffleNRepeatPerShuffle>{},
                       Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
    }
    __host__ __device__ static constexpr auto GetClusterLengthReduction()
    {
        constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
        constexpr auto NPerBlockReduction =
            NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
        constexpr auto MPerBlockReduction =
            (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;

        return Sequence<MPerBlockReduction, NPerBlockReduction>{};
    }
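    // Worked example (assuming BlockSize = 256, NPerBlock = 128 and
    // CBlockTransferScalarPerVector_NWaveNPerXDL = 8): NPerBlockPow2 = 128,
    // NPerBlockReduction = 128 / 8 = 16 vector lanes along N, and
    // MPerBlockReduction = (256 + 15) / 16 = 16 rows along M, i.e. a 16 x 16
    // thread cluster in which each thread reduces one 8-wide vector.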
    __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
    {
        const auto c_partial_acc_block_m_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<MPerBlock>{}, Number<NPerBlock>{}),
                    make_tuple(Number<NPerBlock>{}, I1));
            }
            else
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<MPerBlock>{}, Number<NPerBlock>{}),
                    make_tuple(I1, Number<MPerBlock>{}));
            }
        }();

        return c_partial_acc_block_m_n;
    }
    static __device__ void Run(const FloatAB* p_a_grid,
                               const FloatAB* p_b_grid,
                               FloatC* p_c_grid,
                               void* p_workspace,
                               index_t M,
                               index_t N,
                               index_t K,
                               index_t StrideA,
                               index_t StrideB,
                               index_t StrideC,
                               Block2CTileMap block_mapping,
                               void* __restrict__ p_shared_block)
    {
        uint32_t m = M;
        uint32_t n = N;

        uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
        uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);

        const AElementwiseOperation a_element_op = AElementwiseOperation{};
        const BElementwiseOperation b_element_op = BElementwiseOperation{};
        const CElementwiseOperation c_element_op = CElementwiseOperation{};

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        constexpr auto max_lds_align = K1;

        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_block_desc_k0_m_k1),
                                                                decltype(b_block_desc_k0_n_k1),
                                                                MPerXDL,
                                                                NPerXDL,
                                                                MRepeat,
                                                                NRepeat,
                                                                K1>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
        constexpr auto a_block_space_size = math::integer_least_multiple(
            a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
        FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
        uint32_t block_idx = block_mapping.get_block_idx();
        bool is_sk_block = block_idx < block_mapping.sk_num_blocks;
        bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
                           block_idx < block_mapping.reduction_start_block_idx;
        bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
        bool is_padding_block   = block_idx >= block_mapping.sk_num_blocks &&
                                  block_idx < block_mapping.dp_start_block_idx;
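        // Stream-K role assignment: blocks below sk_num_blocks cooperate on the C
        // tiles that do not divide evenly across the grid, blocks in
        // [dp_start_block_idx, reduction_start_block_idx) own whole tiles, blocks
        // from reduction_start_block_idx onward only combine partial accumulators,
        // and indices between the sk and dp ranges are padding with no work.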
        uint32_t iter_start, iter_end;
        block_mapping.get_block_itr(block_idx, iter_start, iter_end);
        uint32_t total_iter_length = iter_end - iter_start;

        uint32_t* p_semaphore =
            reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
                                        block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));

        workgroup_barrier wg_barrier(p_semaphore);
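        // Workspace layout implied by get_workspace_size_for_acc: the FloatAcc
        // partial-accumulator tiles come first, immediately followed by the
        // uint32_t semaphores that workgroup_barrier operates on.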
        if(is_reduction_block)
        {
            constexpr auto cluster_length_reduce = GetClusterLengthReduction();
            constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
            const auto reduce_thread_cluster_idx =
                reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
            const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
            const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
            constexpr auto MReduceIters =
                math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
            constexpr auto NReduceIters = math::integer_divide_ceil(
                Number<NPerBlock>{},
                cluster_length_reduce.At(I1) *
                    Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{});

            constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
                make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
            constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
                make_tuple(I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));

            constexpr auto partial_acc_load_step_n = make_multi_index(
                0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
            constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
                0,
                -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                    CBlockTransferScalarPerVector_NWaveNPerXDL);
            constexpr auto partial_acc_load_step_m =
                make_multi_index(cluster_length_reduce.At(I0), 0);

            constexpr auto partial_acc_store_step_n = make_multi_index(
                0, 0, 0,
                cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
            constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
                0, 0, 0,
                -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                    CBlockTransferScalarPerVector_NWaveNPerXDL);
            constexpr auto partial_acc_store_step_m =
                make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);

            // thread-local staging for one reduction vector
            auto acc_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
                                               FloatAcc,
                                               CBlockTransferScalarPerVector_NWaveNPerXDL,
                                               true>{};
            auto acc_load_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
                                             FloatAcc,
                                             CBlockTransferScalarPerVector_NWaveNPerXDL,
                                             true>{};
            auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
            auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);

            uint32_t tile_acc_offset_start =
                block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
            uint32_t tile_acc_offset_end =
                block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
            auto acc_load = ThreadwiseTensorSliceTransfer_v2<
                FloatAcc,
                FloatAcc,
                decltype(c_partial_acc_block_m_n),
                decltype(acc_thread_buf_load_desc),
                Sequence<1, CBlockTransferScalarPerVector_NWaveNPerXDL>,
                Sequence<0, 1>,
                1,
                CBlockTransferScalarPerVector_NWaveNPerXDL,
                1,
                false>{c_partial_acc_block_m_n,
                       make_multi_index(thread_m_cluster_id,
                                        thread_n_cluster_id *
                                            CBlockTransferScalarPerVector_NWaveNPerXDL)};
            auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(acc_thread_buf_store_desc),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                CElementwiseOperation,
                Sequence<1, 1, 1, CBlockTransferScalarPerVector_NWaveNPerXDL>,
                Sequence<0, 1, 2, 3>,
                3,
                CBlockTransferScalarPerVector_NWaveNPerXDL,
                InMemoryDataOperationEnum::Set,
                1,
                false>{c_grid_desc_mblock_mperblock_nblock_nperblock,
                       make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
                                        thread_m_cluster_id,
                                        __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
                                        thread_n_cluster_id *
                                            CBlockTransferScalarPerVector_NWaveNPerXDL),
                       CElementwiseOperation{}};
            wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
            if(threadIdx.x == 0)
            {
                printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n",
                       static_cast<int>(blockIdx.x),
                       reduction_idx,
                       __builtin_amdgcn_readfirstlane(tile_acc_offset_start),
                       __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
                       __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
                       __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
            }
            using Accumulation = ck::detail::
                AccumulateWithNanCheck<false /* PropagateNan */, reduce::Add, FloatAcc>;
            for(int i_m = 0; i_m < MReduceIters; i_m++)
            {
                static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
                    acc_thread_buf.Clear();

                    for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
                    {
                        auto c_partial_acc_buf =
                            make_dynamic_buffer<AddressSpaceEnum::Global,
                                                AmdBufferCoherenceEnum::GLC>(
                                reinterpret_cast<FloatAcc*>(p_workspace) +
                                    i * c_partial_acc_block_m_n.GetElementSpaceSize(),
                                c_partial_acc_block_m_n.GetElementSpaceSize());

                        acc_load.Run(c_partial_acc_block_m_n,
                                     c_partial_acc_buf,
                                     acc_thread_buf_load_desc,
                                     make_tuple(I0, I0),
                                     acc_load_buf);

                        static_for<0, CBlockTransferScalarPerVector_NWaveNPerXDL, 1>{}(
                            [&](auto i_vec) {
                                constexpr auto offset =
                                    acc_thread_buf_load_desc.CalculateOffset(
                                        make_tuple(0, i_vec));
                                Accumulation::Calculate(acc_thread_buf(Number<offset>{}),
                                                        acc_load_buf[Number<offset>{}]);
                            });
                    }

                    if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
                       NPerBlock)
                    {
                        acc_store.Run(acc_thread_buf_store_desc,
                                      make_tuple(I0, I0, I0, I0),
                                      acc_thread_buf,
                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
                                      c_grid_buf);
                    }
                    if constexpr(NReduceIters != 1)
                    {
                        if constexpr(i_n_reduce != (NReduceIters - 1))
                        {
                            acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                        partial_acc_load_step_n);
                            acc_store.MoveDstSliceWindow(
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                partial_acc_store_step_n);
                        }
                        else
                        {
                            acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                        partial_acc_load_step_n_reverse);
                            acc_store.MoveDstSliceWindow(
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                partial_acc_store_step_n_reverse);
                        }
                    }
                });

                acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                            partial_acc_load_step_m);
                acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
                                             partial_acc_store_step_m);
            }
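            // Traversal order above: the reduction window sweeps N forward and, on
            // the last column, snaps back with the *_n_reverse steps before advancing
            // one cluster row in M, keeping successive global accesses contiguous.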
        uint32_t block_acc_offset =
            (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
            NPerBlock;

        uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
            block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));

        uint32_t tile_idx, iter_offset;
        block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
        iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);

        auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n);

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock);

        const index_t k0_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
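        // Each stream-k iteration denotes one K0PerBlock slice of one C tile:
        // get_tile_idx_with_offset splits the flat iteration index into
        // (tile_idx, iter_offset), tile_to_spatial gives the tile's (m, n) block
        // coordinates, and iter_offset selects where along K this block starts.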
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                decltype(a_k0_m_k1_grid_desc),
                                                decltype(a_block_desc_k0_m_k1),
                                                ABlockTransferSrcAccessOrder,
                                                ABlockTransferSrcVectorDim,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                AThreadTransferSrcResetCoordinateAfterRun>(
                a_k0_m_k1_grid_desc,
                make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_k0_m_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                decltype(b_k0_n_k1_grid_desc),
                                                decltype(b_block_desc_k0_n_k1),
                                                BBlockTransferSrcAccessOrder,
                                                BBlockTransferSrcVectorDim,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                BThreadTransferSrcResetCoordinateAfterRun>(
                b_k0_n_k1_grid_desc,
                make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_k0_n_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});
        const index_t num_k_block_main_loop = current_iter_length;
        gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
                                   a_block_desc_k0_m_k1,
                                   a_blockwise_copy,
                                   a_grid_buf,
                                   a_block_buf,
                                   a_block_slice_copy_step,
                                   b_k0_n_k1_grid_desc,
                                   b_block_desc_k0_n_k1,
                                   b_blockwise_copy,
                                   b_grid_buf,
                                   b_block_buf,
                                   b_block_slice_copy_step,
                                   blockwise_gemm,
                                   c_thread_buf,
                                   num_k_block_main_loop);
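        // The pipeline overlaps the A/B blockwise copies (global -> LDS) with the
        // xdlops math on the K0PerBlock slice already resident in LDS, repeating
        // for num_k_block_main_loop iterations of this block's K range.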
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
        constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
        constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
        constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
        constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
        constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
        constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
        constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
        constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
            GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle();
        constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
            GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle();

        auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatCShuffle*>(p_shared_block),
            c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());
        auto c_partial_acc_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                reinterpret_cast<FloatAcc*>(p_workspace) + block_acc_offset,
                c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
                    .GetElementSpaceSize());

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
            c_block_desc_mblock_mpershuffle_nblock_npershuffle,
        const auto c_thread_mtx_on_block =
            blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));
        auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
            FloatAcc,
            FloatCShuffle,
            decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            ck::tensor_operation::element_wise::PassThrough,
            Sequence<CShuffleMRepeatPerShuffle,
                     CShuffleNRepeatPerShuffle,
                     I1,
                     I1,
                     M2,
                     M3,
                     M4,
                     N2>,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
            7,
            1,
            InMemoryDataOperationEnum::Set,
            1,
            true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                  make_multi_index(0,
                                   0,
                                   m_thread_data_on_block_idx[I1],
                                   n_thread_data_on_block_idx[I1],
                                   m_thread_data_on_block_idx[I2],
                                   m_thread_data_on_block_idx[I3],
                                   m_thread_data_on_block_idx[I4],
                                   n_thread_data_on_block_idx[I2]),
                  ck::tensor_operation::element_wise::PassThrough{}};
        auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
            ThisThreadBlock,
            CElementwiseOperation,
            Sequence<1,
                     CShuffleMRepeatPerShuffle * MWave * MPerXDL,
                     1,
                     CShuffleNRepeatPerShuffle * NWave * NPerXDL>,
            CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>,
            FloatCShuffle,
            FloatC,
            decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
            3,
            CBlockTransferScalarPerVector_NWaveNPerXDL,
            false,
            false>{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                   make_multi_index(0, 0, 0, 0),
                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                   make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
                                    0,
                                    __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
                                    0),
                   c_element_op};
        auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
            ThisThreadBlock,
            CElementwiseOperation,
            Sequence<1,
                     CShuffleMRepeatPerShuffle * MWave * MPerXDL,
                     1,
                     CShuffleNRepeatPerShuffle * NWave * NPerXDL>,
            CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>,
            FloatCShuffle,
            FloatAcc,
            decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
            decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
            3,
            CBlockTransferScalarPerVector_NWaveNPerXDL,
            false,
            false>{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                   make_multi_index(0, 0, 0, 0),
                   c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                   make_multi_index(0, 0, 0, 0),
                   c_element_op};
        constexpr auto mxdlperwave_forward_step =
            make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
        constexpr auto nxdlperwave_forward_step =
            make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
        constexpr auto nxdlperwave_backward_step =
            make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);
        static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
            constexpr auto mxdlperwave = mxdlperwave_iter;

            constexpr bool nxdlperwave_forward_sweep =
                (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

            static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
                constexpr index_t nxdlperwave_value =
                    nxdlperwave_forward_sweep
                        ? nxdlperwave_iter
                        : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);

                constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
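                // Serpentine shuffle sweep: even M passes walk N forward, odd ones
                // walk it backward, so the destination window only ever moves by one
                // shuffle step. E.g. MRepeat = NRepeat = 4 with both shuffle sizes 2
                // visits (m, n) = (0,0), (0,2), (2,2), (2,0).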
                c_thread_copy_vgpr_to_lds.Run(
                    c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                    make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
                    c_thread_buf,
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    c_block_buf);
                block_sync_lds();

                if(is_dp_block)
                {
                    c_block_copy_lds_to_global.SetSrcSliceOrigin(
                        c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                        make_tuple(0, 0, 0, 0));

                    c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
                                                            decltype(c_grid_buf),
                                                            InMemoryDataOperationEnum::Set>(
                        c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                        c_block_buf,
                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                        c_grid_buf);
                }
                else if(is_sk_block)
                {
                    if constexpr(Block2CTileMap::ReductionStrategy ==
                                 StreamKReductionStrategy::Reduction)
                    {
                        c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
                            c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                            make_tuple(0, 0, 0, 0));
                        c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
                            c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                            make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));

                        c_block_copy_lds_to_partial_acc
                            .template Run<decltype(c_block_buf),
                                          decltype(c_partial_acc_buf),
                                          InMemoryDataOperationEnum::Set>(
                                c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                                c_block_buf,
                                c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                                c_partial_acc_buf);
                    }
                    else if constexpr(Block2CTileMap::ReductionStrategy ==
                                      StreamKReductionStrategy::Atomic)
                    {
                        c_block_copy_lds_to_global
                            .template Run<decltype(c_block_buf),
                                          decltype(c_grid_buf),
                                          InMemoryDataOperationEnum::AtomicAdd>(
                                c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                                c_block_buf,
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                c_grid_buf);
                    }
                }
                // move on the nxdlperwave dimension
                if constexpr(nxdlperwave_forward_sweep &&
                             (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
                {
                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                        nxdlperwave_forward_step);
                }
                else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                {
                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                        nxdlperwave_backward_step);
                }
            });

            // move on the mxdlperwave dimension
            if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
            {
                c_block_copy_lds_to_global.MoveDstSliceWindow(
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    mxdlperwave_forward_step);
            }
        });

        if constexpr(Block2CTileMap::ReductionStrategy ==
                     StreamKReductionStrategy::Reduction)
        {
            if(is_sk_block)
            {
                wg_barrier.inc(tile_idx);
            }
        }
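        // Producer side of the tile semaphore: each stream-k block that wrote a
        // partial accumulator for tile_idx increments its counter; the matching
        // reduction block waits in wait_eq above until all contributions landed.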
        iter_end -= current_iter_length;
        if(iter_end <= iter_start)
            break;

        block_acc_offset -= MPerBlock * NPerBlock;
    template <typename Layout>
    struct LStr
    {
        static std::string Get() { return ""; }
    };

    template <>
    struct LStr<tensor_layout::gemm::RowMajor>
    {
        static std::string Get() { return "R"; }
    };

    template <>
    struct LStr<tensor_layout::gemm::ColumnMajor>
    {
        static std::string Get() { return "C"; }
    };
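    // GetTypeString() below assembles a human-readable kernel tag; purely as an
    // illustration, row-major A/B/C with BlockSize 256 and 8-wide vector accesses
    // would yield something like "GemmXdlStreamK_RRR_B256_Vec8x8x8_...".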
    static std::string GetTypeString()
    {
        auto str = std::stringstream();

        str << "GemmXdlStreamK_"
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            << "_"
            << "B" << BlockSize << "_"
            << "Vec" << ABlockTransferSrcScalarPerVector << "x"
            << BBlockTransferSrcScalarPerVector << "x"
            << CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
            << MPerBlock << "x"
            << NPerBlock << "x"
            << K0PerBlock << "x"