template <typename GridwiseGemm>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
__global__ void kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
                                           const typename GridwiseGemm::FloatAB* p_b_grid,
                                           typename GridwiseGemm::FloatC* p_c_grid,
                                           void* p_workspace,
                                           index_t M,
                                           index_t N,
                                           index_t K,
                                           index_t StrideA,
                                           index_t StrideB,
                                           index_t StrideC,
                                           typename GridwiseGemm::Block2CTileMap block_mapping)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
        constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
        __shared__ uint8_t p_shared[shared_size];

        GridwiseGemm::Run(p_a_grid,
                          p_b_grid,
                          p_c_grid,
                          p_workspace,
                          M,
                          N,
                          K,
                          StrideA,
                          StrideB,
                          StrideC,
                          block_mapping,
                          static_cast<void*>(p_shared));
    }
#else
    // ... (arguments consumed via `ignore` on unsupported targets)
#endif
}
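// The wrapper above only sizes and reserves static LDS, then forwards its
// arguments; every stream-K decision is made inside GridwiseGemm::Run. The
// architecture guard compiles the body only for MFMA/XDL-capable targets
// (gfx908/90a/94x and gfx12xx); elsewhere the kernel is an empty stub.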
 
template <// ... (leading size/type parameters elided)
          typename Block2CTileMap_,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
struct /* GridwiseGemm (identifier elided in this excerpt) */
{
    // ...
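    // The parameters group naturally: the block-to-C-tile mapping policy, the
    // per-operand elementwise ops, the A- and B-side block-transfer
    // descriptions (thread-cluster shape, access order, vector dimension and
    // widths for the global read and the LDS write, and whether the source
    // coordinate resets after each run), and finally the C-shuffle epilogue
    // tuning (repeats moved per shuffle and the vector width of the
    // N-contiguous output store).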
 
    static constexpr auto M01 = 1;
    static constexpr auto N01 = 1;
 
            std::cout << "arg {"
                      << "M:" << M << ", "
                      << "N:" << N << ", "
                      << "K:" << K << ", "
                      // ...
    __host__ __device__ static auto
    MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
    {
        // ...
        const auto a_grid_desc_m_k = [&]() {
            // ...
        }();
        // ...
    }

    __host__ __device__ static auto
    MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
    {
        // ...
        const auto b_grid_desc_k_n = [&]() {
            // ...
        }();
        // ...
    }

    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        // ...
        const auto c_grid_desc_m_n = [&]() {
            // ...
        }();
        // ...
    }
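    // Each Make*GridDescriptor helper turns a raw pointer-plus-stride view
    // into a CK tensor descriptor. Judging from the transforms this file uses
    // (make_naive_tensor_descriptor, make_right_pad_transform,
    // make_unmerge_transform), the elided lambda bodies build the naive
    // M x K / K x N / M x N view, pad it up to MPad/KPad/NPad, and unmerge K
    // into the K0 x {M,N} x K1 shape the block transfers expect.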
 
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;
        // ...
        constexpr auto a_block_space_size_aligned =
            // ...
        constexpr auto b_block_space_size_aligned =
            // ...
        constexpr auto c_block_size =
            // ...

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                         // ...
    }
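    // LDS is sized for the larger of the two phases that use it: the main
    // loop needs the K1-aligned A tile plus the B tile (in FloatAB elements),
    // while the epilogue reuses the same storage as the C-shuffle staging
    // tile, so the return value is effectively
    // max(mainloop bytes, c_block_size bytes).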
 
    __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
    {
        // ... (layout-dependent branches elided)
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
                // ...
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
                // ...
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
                // ...
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
                // ...
            if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
                // ...
            if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
                // ...
        // ...
    }

    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = K0 > K0PerBlock;

        return has_main_k0_block_loop;
    }
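    // CheckValidity rejects problems whose contiguous dimension is not a
    // multiple of the configured vector width; which of K/M (for A), K/N
    // (for B) and M/N (for C) gets checked depends on the layout branch
    // elided above. CalculateHasMainK0BlockLoop reports whether more than one
    // K0PerBlock slab must be looped over, selecting the pipelined main loop.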
 
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;
        // ...
    }

    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCBlockClusterAdaptor(const CGridDesc& c_m_n_grid_desc, index_t, index_t, index_t KBatch)
    {
        return /* map type elided */ (c_m_n_grid_desc, 8, KBatch);
    }
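    // M / MPerBlock and N / NPerBlock divide exactly because the C descriptor
    // was already padded to whole tiles; the adaptor maps a flat block id to
    // a (block-m, block-n) tile, ignoring the two unused swizzle arguments
    // and forwarding a constant 8 plus the KBatch split to the underlying map.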
 
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
        constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 1 : NPerBlock / (NRepeat * NPerXdl);
        // ...
    }

    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
        constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 1 : NPerBlock / (NRepeat * NPerXdl);
        // ...
                       Number<NRepeat / CShuffleNRepeatPerShuffle>{},
        // ...
    }

    __host__ __device__ static constexpr auto GetClusterLengthReduction()
    {
        constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
        constexpr auto NPerBlockReduction =
            NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
        constexpr auto MPerBlockReduction =
            (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
        // ...
    }

    __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
    {
        // ...
        const auto c_partial_acc_block_m_n = [&]() {
            // ...
        }();

        return c_partial_acc_block_m_n;
    }
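    // GetClusterLengthReduction arranges the whole workgroup as an
    // MPerBlockReduction x NPerBlockReduction grid over one C tile: N is
    // rounded up to a power of two and divided by the store vector width, and
    // the remaining threads are stacked along M via a ceiling division of
    // BlockSize. Illustrative numbers: BlockSize = 256, NPerBlock = 128 and
    // vector width 4 give 32 lanes along N and ceil(256 / 32) = 8 rows along M.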
 
    static __device__ void Run(const FloatAB* p_a_grid,
                               const FloatAB* p_b_grid,
                               FloatC* p_c_grid,
                               void* p_workspace,
                               index_t M,
                               index_t N,
                               index_t K,
                               index_t StrideA,
                               index_t StrideB,
                               index_t StrideC,
                               Block2CTileMap block_mapping,
                               void* __restrict__ p_shared_block)
    {
        // ...
        uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
        uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
        // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            // ...

        const AElementwiseOperation a_element_op = AElementwiseOperation{};
        const BElementwiseOperation b_element_op = BElementwiseOperation{};
        const CElementwiseOperation c_element_op = CElementwiseOperation{};

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
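        // pad_m / pad_n round the problem up to whole C tiles (e.g. m = 1000
        // with MPerBlock = 128 pads to 1024), so the grid descriptors built
        // here always describe an integral number of blocks; the elementwise
        // ops are stateless and simply default-constructed.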
 
        constexpr auto max_lds_align = K1;
        // ... (a_block_desc_k0_m_k1 / b_block_desc_k0_n_k1 elided)

        auto blockwise_gemm =
            /* blockwise xdlops GEMM type elided */<
                // ...
                decltype(a_block_desc_k0_m_k1),
                decltype(b_block_desc_k0_n_k1),
                // ...

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        constexpr auto a_block_space_size =
            // ...
        FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
        FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
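        // A and B share the single p_shared_block allocation: the A tile
        // occupies the first a_block_space_size elements and the B tile
        // starts right after it, mirroring the arithmetic in
        // GetSharedMemoryNumberOfByte(). Both slice-copy steps advance one
        // K0PerBlock slab along K0 per main-loop iteration.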
 
        uint32_t block_idx = block_mapping.get_block_idx();

        bool is_sk_block        = block_idx < block_mapping.sk_num_blocks;
        bool is_dp_block        = block_idx >= block_mapping.dp_start_block_idx &&
                                  block_idx < block_mapping.reduction_start_block_idx;
        bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
        bool is_padding_block   = block_idx >= block_mapping.sk_num_blocks &&
                                  block_idx < block_mapping.dp_start_block_idx;

        // ...
        block_mapping.get_block_itr(block_idx, iter_start, iter_end);
        uint32_t total_iter_length = iter_end - iter_start;
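        // Stream-K partitioning: block ids fall into four contiguous ranges.
        // "sk" blocks cooperatively cover the leftover tiles, each owning an
        // arbitrary [iter_start, iter_end) span of K0 slabs; "padding" blocks
        // do no work; "dp" blocks execute classic data-parallel whole tiles;
        // trailing "reduction" blocks combine the partial results the sk
        // blocks spill to the workspace. get_block_itr translates a block id
        // into its iteration span.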
 
        // ... (semaphore array located after the accumulator region)
            reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
                                        block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
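        // The workspace is laid out as [partial-accumulator tiles | semaphores]:
        // the uint32_t semaphore array consumed by wg_barrier (a
        // workgroup_barrier) starts immediately after the accumulator region,
        // whose byte size get_workspace_size_for_acc(sizeof(FloatAcc)) reports.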
 
            if(is_reduction_block)
            {
                // ... (cluster_length_reduce, cf. GetClusterLengthReduction)
                const auto reduce_thread_cluster_idx =
                    // ...
                const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
                const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];

                constexpr auto MReduceIters =
                    // ...
                    cluster_length_reduce.At(I1) *
                    // ...

                constexpr auto partial_acc_load_step_n = make_multi_index(
                    0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
                constexpr auto partial_acc_load_step_n_reverse =
                    make_multi_index(// ...
                                     -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                                         CBlockTransferScalarPerVector_NWaveNPerXDL);
                constexpr auto partial_acc_load_step_m =
                    // ...

                constexpr auto partial_acc_store_step_n = make_multi_index(
                    // ...
                    cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
                constexpr auto partial_acc_store_step_n_reverse =
                    make_multi_index(// ...
                                     -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                                         CBlockTransferScalarPerVector_NWaveNPerXDL);
                constexpr auto partial_acc_store_step_m =
                    // ...

                // ... (acc_thread_buf_load_desc / acc_thread_buf_store_desc,
                //      both vectorized by CBlockTransferScalarPerVector_NWaveNPerXDL)

                auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
                auto spatial_idx   = block_mapping.tile_to_spatial(reduction_idx, m, n);

                const auto tile_acc_offset_start =
                    block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
                const auto tile_acc_offset_end =
                    block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
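                // Each reduction block owns one C tile: reduction_idx recovers
                // its (m, n) coordinates via tile_to_spatial, and the two
                // get_acc_buffer_offset_from_tile calls bound the range of
                // partial-accumulator buffers deposited for that tile. The
                // paired forward/reverse N steps set up a serpentine sweep so
                // the slice windows never rewind across the whole tile.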
 
                // acc_load: threadwise transfer, partial-acc workspace -> thread buffer
                auto acc_load = /* transfer type elided */<
                    decltype(c_partial_acc_block_m_n),
                    decltype(acc_thread_buf_load_desc),
                    // ...
                    CBlockTransferScalarPerVector_NWaveNPerXDL,
                    // ...
                    >{c_partial_acc_block_m_n,
                      // ...
                                       thread_n_cluster_id *
                                           CBlockTransferScalarPerVector_NWaveNPerXDL)};

                // acc_store: threadwise transfer, thread buffer -> C grid
                auto acc_store = /* transfer type elided */<
                    decltype(acc_thread_buf_store_desc),
                    decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                    CElementwiseOperation,
                    // ...
                    CBlockTransferScalarPerVector_NWaveNPerXDL,
                    // ...
                    >{c_grid_desc_mblock_mperblock_nblock_nperblock,
                      // ...
                                       __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
                                       thread_n_cluster_id *
                                           CBlockTransferScalarPerVector_NWaveNPerXDL),
                      CElementwiseOperation{}};
 
                wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);

                // ... (debug print)
                if(threadIdx.x == 0)
                {
                    printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n",
                           static_cast<int>(blockIdx.x),
                           reduction_idx,
                           __builtin_amdgcn_readfirstlane(tile_acc_offset_start),
                           __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
                           __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
                           __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
                }
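                // wait_eq spins on this tile's semaphore until its value
                // equals the number of partial buffers expected for the tile,
                // i.e. until every contributing sk block has executed its
                // wg_barrier.inc further down; only then is it safe to read
                // the partial accumulators.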
 
                using Accumulation = ck::detail::
                    // ... (accumulation helper over the configured reduce op)

                for(int i_m = 0; i_m < MReduceIters; i_m++)
                {
                    // ... (static_for over i_n_reduce:)
                        for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
                        {
                            auto c_partial_acc_buf =
                                // ... (make_dynamic_buffer over)
                                    reinterpret_cast<FloatAcc*>(p_workspace) +
                                        i * c_partial_acc_block_m_n.GetElementSpaceSize(),
                                    c_partial_acc_block_m_n.GetElementSpaceSize());

                            acc_load.Run(c_partial_acc_block_m_n,
                                         // ...
                                         acc_thread_buf_load_desc,
                                         // ...

                            // ... (fold the loaded slice into the thread buffer)
                                    constexpr auto offset =
                                        acc_thread_buf_load_desc.CalculateOffset(
                                            // ...
                        }

                        if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
                           // ...
                        {
                            acc_store.Run(acc_thread_buf_store_desc,
                                          // ...
                                          c_grid_desc_mblock_mperblock_nblock_nperblock,
                                          // ...
                        }

                        if constexpr(NReduceIters != 1)
                        {
                            if constexpr(i_n_reduce != (NReduceIters - 1))
                            {
                                acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                            partial_acc_load_step_n);
                                acc_store.MoveDstSliceWindow(
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                    partial_acc_store_step_n);
                            }
                            else
                            {
                                acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                            partial_acc_load_step_n_reverse);
                                acc_store.MoveDstSliceWindow(
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                    partial_acc_store_step_n_reverse);
                            }
                        }
                    // ...
                        acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                    partial_acc_load_step_m);
                        acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                     partial_acc_store_step_m);
                }
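                // The inner loop folds every partial buffer for the tile into
                // the thread buffer element-wise (Accumulation applies the
                // configured reduce operator), a guarded acc_store then
                // writes the combined slice to C, and the slice windows
                // advance in the serpentine order prepared earlier.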
 
            // ... (sk blocks: offset of this block's last partial-acc tile)
            block_acc_offset =
                (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
                NPerBlock;

            // ... (loop over this block's iteration span, newest tile first)
            uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
                block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
            // ...
            block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
            iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);

            auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n);

            const index_t m_block_data_idx_on_grid =
                __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock);
            const index_t n_block_data_idx_on_grid =
                __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock);
            const index_t k0_block_data_idx_on_grid =
                __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
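            // A stream-K block walks its iteration span back to front: it
            // locates the tile containing iter_end - 1, derives how many K0
            // slabs of that tile it owns (current_iter_length), and converts
            // the tile/offset pair into grid coordinates. readfirstlane keeps
            // these block-uniform values in scalar registers.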
 
            auto a_blockwise_copy =
                /* thread-group transfer type elided */<
                    // ...
                    AElementwiseOperation,
                    // ...
                    ABlockTransferThreadClusterLengths_K0_M_K1,
                    ABlockTransferThreadClusterArrangeOrder,
                    // ...
                    decltype(a_k0_m_k1_grid_desc),
                    decltype(a_block_desc_k0_m_k1),
                    ABlockTransferSrcAccessOrder,
                    // ...
                    ABlockTransferSrcVectorDim,
                    // ...
                    ABlockTransferSrcScalarPerVector,
                    ABlockTransferDstScalarPerVector_K1,
                    // ...
                    AThreadTransferSrcResetCoordinateAfterRun,
                    // ...
                    a_block_desc_k0_m_k1,
                    // ...

            auto b_blockwise_copy =
                /* thread-group transfer type elided */<
                    // ...
                    BElementwiseOperation,
                    // ...
                    BBlockTransferThreadClusterLengths_K0_N_K1,
                    BBlockTransferThreadClusterArrangeOrder,
                    // ...
                    decltype(b_k0_n_k1_grid_desc),
                    decltype(b_block_desc_k0_n_k1),
                    BBlockTransferSrcAccessOrder,
                    // ...
                    BBlockTransferSrcVectorDim,
                    // ...
                    BBlockTransferSrcScalarPerVector,
                    BBlockTransferDstScalarPerVector_K1,
                    // ...
                    BThreadTransferSrcResetCoordinateAfterRun,
                    // ...
                    b_block_desc_k0_n_k1,
                    // ...

            const index_t num_k_block_main_loop = current_iter_length;

            gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
                                       a_block_desc_k0_m_k1,
                                       // ...
                                       a_block_slice_copy_step,
                                       // ...
                                       b_block_desc_k0_n_k1,
                                       // ...
                                       b_block_slice_copy_step,
                                       // ...
                                       num_k_block_main_loop);
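            // The pipeline object (cf. the gridwise GEMM pipeline header this
            // file pulls in) interleaves the A/B blockwise copies with
            // blockwise_gemm for num_k_block_main_loop = current_iter_length
            // K0 slabs, i.e. exactly the slice of K this block owns for the
            // current tile.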
 
            // ... (epilogue: shuffle C through LDS)
                constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
                constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);

                constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
                    blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
                constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
                    blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
                constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
                constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
                constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
                constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
                constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
                constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
                constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);

                constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
                    // ...
                constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
                    // ...

                auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    // ...
                    c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());

                auto c_partial_acc_buf =
                    make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                        reinterpret_cast<FloatAcc*>(p_workspace) + block_acc_offset,
                        c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
                            .GetElementSpaceSize());
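                // The GLC (write-through) coherence on c_partial_acc_buf
                // keeps the spilled partial tiles out of the workgroup-private
                // cache path, so a reduction block on another CU observes them
                // once the semaphore is raised; c_block_buf is the LDS staging
                // tile the VGPR accumulators are shuffled through first.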
 
                // ... (c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2: the LDS tile
                //      viewed in blockwise-gemm output coordinates, built from)
                    c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                    // ...

                const auto c_thread_mtx_on_block =
                    blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

                const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
                const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

                const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                    // ...
                const auto m_thread_data_on_block_idx =
                    m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                        // ...

                const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                    // ...
                const auto n_thread_data_on_block_idx =
                    n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                        // ...

                // c_thread_copy_vgpr_to_lds: VGPR accumulators -> LDS staging tile
                auto c_thread_copy_vgpr_to_lds = /* transfer type elided */<
                    decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
                    decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                    // ...
                             CShuffleNRepeatPerShuffle,
                    // ...
                    true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                          // ...
                                           m_thread_data_on_block_idx[I1],
                                           n_thread_data_on_block_idx[I1],
                                           m_thread_data_on_block_idx[I2],
                                           m_thread_data_on_block_idx[I3],
                                           m_thread_data_on_block_idx[I4],
                                           n_thread_data_on_block_idx[I2]),
                          // ...
 
                // c_block_copy_lds_to_global: LDS tile -> C grid
                auto c_block_copy_lds_to_global = /* thread-group transfer type elided */<
                    // ...
                    CElementwiseOperation,
                    Sequence<// ...
                             CShuffleMRepeatPerShuffle * MWave * MPerXdl,
                             // ...
                             CShuffleNRepeatPerShuffle * NWave * NPerXdl>,
                    CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                    // ...
                    decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
                    decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                    // ...
                    CBlockTransferScalarPerVector_NWaveNPerXDL,
                    // ...
                    >{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                      // ...
                      c_grid_desc_mblock_mperblock_nblock_nperblock,
                      // ...
                                       __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
                                       // ...

                // c_block_copy_lds_to_partial_acc: LDS tile -> workspace partial tile
                auto c_block_copy_lds_to_partial_acc = /* same transfer shape */<
                    // ...
                    CElementwiseOperation,
                    Sequence<// ...
                             CShuffleMRepeatPerShuffle * MWave * MPerXdl,
                             // ...
                             CShuffleNRepeatPerShuffle * NWave * NPerXdl>,
                    CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                    // ...
                    decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
                    decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
                    // ...
                    CBlockTransferScalarPerVector_NWaveNPerXDL,
                    // ...
                    >{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                      // ...
                      c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                      // ...
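                // Two epilogue writers are prepared: dp blocks (and sk blocks
                // under the Atomic strategy) stream the LDS tile straight into
                // C, while sk blocks under the Reduction strategy spill it
                // into their workspace slot for a reduction block to combine
                // later.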
 
                constexpr auto mxdlperwave_forward_step =
                    // ...
                constexpr auto nxdlperwave_forward_step =
                    // ...
                constexpr auto nxdlperwave_backward_step =
                    // ...

                // ... (static_for over mxdlperwave_iter:)
                    constexpr auto mxdlperwave = mxdlperwave_iter;

                    // ... (static_for over nxdlperwave_iter:)
                        constexpr bool nxdlperwave_forward_sweep =
                            (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

                        constexpr index_t nxdlperwave_value =
                            nxdlperwave_forward_sweep
                                ? nxdlperwave_iter
                                : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
                        constexpr auto nxdlperwave = Number<nxdlperwave_value>{};

                        // ...
                        c_thread_copy_vgpr_to_lds.Run(
                            c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                            // ...
                            c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                            // ...

                        block_sync_lds();

                        c_block_copy_lds_to_global.SetSrcSliceOrigin(
                            c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                            // ...

                        if(is_dp_block)
                        {
                            c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
                                                                    decltype(c_grid_buf),
                                                                    // ...
                                c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                                // ...
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                // ...
                        }
                        else if(is_sk_block)
                        {
                            if constexpr(Block2CTileMap::ReductionStrategy ==
                                         StreamKReductionStrategy::Reduction)
                            {
                                c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
                                    c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                                    // ...
                                c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
                                    c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                                    make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));

                                c_block_copy_lds_to_partial_acc
                                    .template Run<decltype(c_block_buf),
                                                  decltype(c_partial_acc_buf),
                                                  // ...
                                        c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                                        // ...
                                        c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                                        // ...
                            }
                            else if constexpr(Block2CTileMap::ReductionStrategy ==
                                              StreamKReductionStrategy::Atomic)
                            {
                                c_block_copy_lds_to_global
                                    .template Run<decltype(c_block_buf),
                                                  decltype(c_grid_buf),
                                                  // ...
                                        c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                                        // ...
                                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                                        // ...
                            }
                        }

                        // move the destination window to the next N chunk
                        if constexpr(nxdlperwave_forward_sweep &&
                                     (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
                        {
                            c_block_copy_lds_to_global.MoveDstSliceWindow(
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                nxdlperwave_forward_step);
                        }
                        else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                        {
                            c_block_copy_lds_to_global.MoveDstSliceWindow(
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                nxdlperwave_backward_step);
                        }
                    // ... (end nxdlperwave static_for)

                    if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
                    {
                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            mxdlperwave_forward_step);
                    }
                // ... (end mxdlperwave static_for)

                if constexpr(Block2CTileMap::ReductionStrategy ==
                             StreamKReductionStrategy::Reduction)
                {
                    // ... (sk blocks signal completion of this tile)
                        wg_barrier.inc(tile_idx);
                    // ...
                }
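                // The sweep walks the C tile one (CShuffleMRepeat x
                // CShuffleNRepeat) chunk at a time, serpentine over N
                // (forward on even M chunks, backward on odd) so only one
                // short MoveDstSliceWindow is needed between chunks. Once an
                // sk block has spilled its whole tile, wg_barrier.inc(tile_idx)
                // releases the reduction block waiting on that tile's semaphore.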
 
            // next chunk of this block's iteration span
            iter_end -= current_iter_length;
            if(iter_end <= iter_start)
            {
                // ... (span exhausted; leave the loop)
            }
            // ... (sk blocks step back to the previous partial-acc tile)
                block_acc_offset -= MPerBlock * NPerBlock;
 
    template <typename Layout>
    struct /* layout-to-letter helper (identifier elided) */
    {
        static std::string Get() { return ""; }
    };

    // ... (specialization for row-major layouts)
        static std::string Get() { return "R"; }

    // ... (specialization for column-major layouts)
        static std::string Get() { return "C"; }
 
    static std::string GetTypeString()
    {
        auto str = std::stringstream();

        // ...
        str << "GemmXdlStreamK_"
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            // ...
            << "B" << BlockSize << "_"
            << "Vec" << ABlockTransferSrcScalarPerVector << "x"
            << BBlockTransferSrcScalarPerVector << "x"
            << CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
            // ...
            << K0PerBlock << "x";
        // ...

        return str.str();
    }
    // ...