template <typename ABDataType,
          typename FloatGemmAcc,
          typename EDataTypeShuffle,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename EElementwiseOperation,
          // ...
          typename AGridDesc_M_K,
          typename BGridDesc_N_K,
          typename EGridDesc_M_N,
          // ...
          index_t TileLoadThreadGroupSize,
          index_t TileMathThreadGroupSize,
          // ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// ... (several further template parameters, the struct declaration these parameters belong
//      to, and its member constants and block-descriptor helpers are elided in this excerpt)
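// Reading note for the fragments below: a hedged summary inferred from the parameter names
// and pipeline calls that survive in this excerpt, not quoted from the header. The thread
// block is split into a tile-load group (TileLoadThreadGroupSize threads) that streams A/B
// tiles from global memory into LDS, and a tile-math group (TileMathThreadGroupSize threads)
// that consumes those tiles with xdlops MFMA; the two groups cooperate through the
// RunLoadWavePipeline / RunMathWavePipeline stages of GridwiseGemmLoad / GridwiseGemmMath
// that appear further down in Run().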
 
    __host__ __device__ static constexpr auto
    GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            /* ... descriptor construction elided in this excerpt ... */;

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }
 
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // ... (A/B block descriptors and max_lds_align elided in this excerpt)

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        // LDS is reused: the A/B staging tiles and the C-shuffle buffer are never live at
        // the same time, so the requirement is the maximum of the two footprints
        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(ABDataType),
                         c_block_size * sizeof(EDataTypeShuffle));
    }

    __device__ static constexpr bool IsValidCompilationParameter()
    {
        return ck::tensor_operation::device::IsValidGemmCompilationParameter<
            /* ... template arguments elided in this excerpt ... */
            CGlobalMemoryDataOperation>();
    }
 
    template <typename Block2ETileMap>
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
                  const BGridDesc_N_K& b_grid_desc_n_k,
                  const EGridDesc_M_N& e_grid_desc_m_n,
                  const Block2ETileMap& /* block_2_etile_map */)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        // A, B and E must describe one consistent (M, N, K) problem
        if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
             K == b_grid_desc_n_k.GetLength(I1)))
        {
            return false;
        }

        // the problem must be tiled exactly by the block sizes
        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
        {
            return false;
        }

        const auto num_k_loop = K / KPerBlock;

        if(!GridwiseGemmMath::IsSupported(num_k_loop))
        {
            return false;
        }

        // each tensor must stay below the 2 GB addressing limit
        if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
             b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
             e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
        {
            return false;
        }

        return true;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return GridwiseGemmMath::CalculateHasMainLoop(num_loop);
    }
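    // A minimal host-side sketch, not taken from this header, of how the two helpers above
    // are typically consulted before a launch. "GridwiseGemm" is assumed to be an alias for
    // a concrete instantiation of this struct, and the descriptor / tile-map variables are
    // assumed to have been built already.
    //
    //   if(!GridwiseGemm::CheckValidity(
    //          a_grid_desc_m_k, b_grid_desc_n_k, e_grid_desc_m_n, block_2_etile_map))
    //   {
    //       throw std::runtime_error("wrong! unsupported GEMM problem for this kernel");
    //   }
    //
    //   const bool has_main_k_block_loop =
    //       GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_m_k.GetLength(I1));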
 
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
    {
        const auto M = e_grid_desc_m_n.GetLength(I0);
        const auto N = e_grid_desc_m_n.GetLength(I1);

        // ... (M1/N1 tile lengths elided in this excerpt)

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        constexpr auto M01 = I1;
        constexpr auto N01 = I1;

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            /* ... adaptor construction elided in this excerpt ... */;

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            /* ... adaptor construction elided in this excerpt ... */;

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }
 
    __host__ __device__ static constexpr index_t
    CalculateGridSize(const EGridDesc_M_N& e_grid_desc_m_n)
    {
        const auto M = e_grid_desc_m_n.GetLength(I0);
        const auto N = e_grid_desc_m_n.GetLength(I1);

        // one workgroup per MPerBlock x NPerBlock tile of E
        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }
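    // A hedged launch-configuration sketch, not taken from this header: the grid size and
    // dynamic-LDS requirement computed above would typically feed the kernel launch.
    // "gridwise_gemm_wavelet_kernel" is a placeholder name for the __global__ wrapper that
    // calls Run; GridwiseGemm, the stream and the pointer/descriptor arguments are likewise
    // assumptions for illustration only.
    //
    //   const index_t grid_size = GridwiseGemm::CalculateGridSize(e_grid_desc_m_n);
    //   const index_t lds_byte  = GridwiseGemm::GetSharedMemoryNumberOfByte();
    //
    //   hipLaunchKernelGGL(gridwise_gemm_wavelet_kernel,
    //                      dim3(grid_size),
    //                      dim3(GridwiseGemm::BlockSize),
    //                      lds_byte,
    //                      stream,
    //                      p_a_grid, p_b_grid, p_e_grid,
    //                      /* ... element ops, descriptors, block-to-tile map ... */);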
 
    __host__ __device__ static constexpr auto
    MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / AK1;

        // ... (descriptor transform elided in this excerpt)
    }

    __host__ __device__ static constexpr auto
    MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / BK1;

        // ... (descriptor transform elided in this excerpt)
    }
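    // A note on the two builders above (editor's hedged reading): both split the K dimension
    // of the M x K / N x K input into (AK0, AK1) resp. (BK0, BK1), so that K is addressed as
    // (AK0, AK1) with AK1/BK1 as the innermost axis of the block-level descriptor; in the
    // full header the elided bodies appear to be transform_tensor_descriptor calls performing
    // this split.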
 
    template <typename EGridDescriptor_M_N>
    __host__ __device__ static constexpr auto
    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDescriptor_M_N& e_grid_desc_m_n)
    {
        const auto M = e_grid_desc_m_n.GetLength(I0);
        const auto N = e_grid_desc_m_n.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            /* ... descriptor transform elided in this excerpt ... */;

        return e_grid_desc_mblock_mperblock_nblock_nperblock;
    }
 
    template <bool HasMainKBlockLoop,
              typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
              typename Block2ETileMap>
    __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
                               const ABDataType* __restrict__ p_b_grid,
                               EDataType* __restrict__ p_e_grid,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const EElementwiseOperation& e_element_op,
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                               const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                               const Block2ETileMap& block_2_etile_map)
    {
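        // High-level flow, reconstructed from the fragments that follow (hedged summary, not
        // quoted from the header): the tile-load thread group stages A/B tiles from global
        // memory into LDS through the two blockwise copies and RunLoadWavePipeline; the
        // tile-math thread group runs the xdlops blockwise GEMM through RunMathWavePipeline;
        // the accumulators are then shuffled through LDS and written to the E grid in the
        // CShuffle epilogue at the end of this function.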
 
        // ... (A/B block descriptors and max_lds_align elided in this excerpt)

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        // A tile at the start of LDS
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        // B tile immediately after the (aligned) A tile
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        // number of K-block iterations = (AK0 * AK1) / KPerBlock
        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        // map the flat workgroup id to this block's (M, N) tile of E
        const auto block_work_idx =
            block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // ... (tile-load thread-group guard elided in this excerpt; the block below is
        //      executed by the load group)
        {
            const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
            const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
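            // LDS partition implied by the buffers above: the A tile occupies the first
            // a_block_space_size_aligned elements of p_shared (as ABDataType), the B tile
            // follows it, and the same p_shared region is later reinterpreted as
            // EDataTypeShuffle storage for the C-shuffle epilogue.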
 
            // A: blockwise copy from global memory into LDS (blockwise transfer type, cf.
            // thread_group_tensor_slice_transfer_v4r1.hpp; name and several arguments
            // elided in this excerpt)
            auto a_blockwise_copy =
                /* ... */<
                    /* ... */
                    AElementwiseOperation,
                    /* ... */
                    ABlockTransferThreadClusterLengths_AK0_M_AK1,
                    ABlockTransferThreadClusterArrangeOrder,
                    /* ... */
                    decltype(a_grid_desc_ak0_m_ak1),
                    decltype(a_block_desc_ak0_m_ak1),
                    ABlockTransferSrcAccessOrder,
                    /* ... */
                    ABlockTransferSrcVectorDim,
                    /* ... */
                    ABlockTransferSrcScalarPerVector,
                    ABlockTransferDstScalarPerVector_AK1,
                    /* ... */
                    AThreadTransferSrcResetCoordinateAfterRun,
                    /* ... */
                    NumGemmKPrefetchStage>(
                    a_grid_desc_ak0_m_ak1,
                    /* ... */
                    a_block_desc_ak0_m_ak1,
                    /* ... */);

            // ... (a_block_slice_copy_step definition elided in this excerpt)

            // B: blockwise copy from global memory into LDS
            auto b_blockwise_copy =
                /* ... */<
                    /* ... */
                    BElementwiseOperation,
                    /* ... */
                    BBlockTransferThreadClusterLengths_BK0_N_BK1,
                    BBlockTransferThreadClusterArrangeOrder,
                    /* ... */
                    decltype(b_grid_desc_bk0_n_bk1),
                    decltype(b_block_desc_bk0_n_bk1),
                    BBlockTransferSrcAccessOrder,
                    /* ... */
                    BBlockTransferSrcVectorDim,
                    /* ... */
                    BBlockTransferSrcScalarPerVector,
                    BBlockTransferDstScalarPerVector_BK1,
                    /* ... */
                    BThreadTransferSrcResetCoordinateAfterRun,
                    /* ... */
                    NumGemmKPrefetchStage>(
                    b_grid_desc_bk0_n_bk1,
                    /* ... */
                    b_block_desc_bk0_n_bk1,
                    /* ... */);

            // ... (b_block_slice_copy_step definition elided in this excerpt)

            // load pipeline: keep streaming K-blocks of A and B into LDS
            GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                /* ... */
                a_block_slice_copy_step,
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                /* ... */
                b_block_slice_copy_step,
                num_k_block_main_loop);
        }
 
        // ... (tile-math thread-group guard elided in this excerpt; the block below is
        //      executed by the math group)
        {
            constexpr bool is_single_rate_mfma =
                /* ... condition elided in this excerpt ... */;

            constexpr auto is_scale_mfma = false;

            // MFMA selection (cf. xdlops_gemm.hpp): a constant derived from an expression
            // ending in
            //     ... is_scale_mfma>::selected_mfma.k_per_blk);
            // feeds the blockwise GEMM configuration below (expression elided in this excerpt)

            // blockwise xdlops GEMM object; its (elided) type is parameterized by, among
            // other arguments:
            //     TileMathThreadGroupSize,
            //     decltype(a_block_desc_ak0_m_ak1),
            //     decltype(b_block_desc_bk0_n_bk1),
            auto blockwise_gemm = /* ... type and constructor elided in this excerpt ... */;

            auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

            auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            // math pipeline: consume the LDS tiles produced by the load group
            GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
                a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);
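            // C-shuffle epilogue (fragments below; hedged summary): the per-thread xdlops
            // accumulators are staged into the LDS shuffle buffer in pieces of
            // CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle Xdl tiles and
            // then copied from LDS to the E grid, advancing along matching space-filling
            // curves on the register and global sides.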
 
            {
                static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                                  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                              /* ... */);

                constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
                constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

                // accumulator layout as seen by one thread
                constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                    blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                // accumulator layout of the whole block (only its lengths are used below)
                constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                    blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
                constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
                constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
                constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
                constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
                constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
                constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
                constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

                constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                    GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

                // the C-shuffle staging buffer reuses the same LDS as the A/B tiles
                auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<EDataTypeShuffle*>(p_shared),
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

                // LDS descriptor re-expressed in the per-thread M0..M4 / N0..N2 layout
                // (transforms elided in this excerpt)
                constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    /* ... */);
 
                // origin of this thread's accumulator tile within the block
                const auto c_thread_mtx_on_block =
                    blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

                const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
                const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

                const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                    /* ... adaptor construction elided in this excerpt ... */;

                const auto m_thread_data_on_block_idx =
                    m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                        make_multi_index(m_thread_data_on_block));

                const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                    /* ... adaptor construction elided in this excerpt ... */;

                const auto n_thread_data_on_block_idx =
                    n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                        make_multi_index(n_thread_data_on_block));
 
                // VGPR -> LDS: per-thread copy of the accumulators into the shuffle buffer
                // (threadwise transfer type, cf. threadwise_tensor_slice_transfer.hpp; name
                //  and several arguments elided in this excerpt)
                auto c_thread_copy_vgpr_to_lds =
                    /* ... */<
                        /* ... */
                        decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                        decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                        /* ... */
                        Sequence<CShuffleMXdlPerWavePerShuffle,
                                 CShuffleNXdlPerWavePerShuffle,
                                 /* ... */>,
                        /* ... */
                        true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                              make_multi_index(/* ... */,
                                               m_thread_data_on_block_idx[I1],
                                               n_thread_data_on_block_idx[I1],
                                               m_thread_data_on_block_idx[I2],
                                               m_thread_data_on_block_idx[I3],
                                               m_thread_data_on_block_idx[I4],
                                               n_thread_data_on_block_idx[I2]),
                              /* ... */};
 
                // LDS -> global: blockwise copy of the shuffled tile into the E grid, run by
                // CShuffleBlockTransferThreadGroup (blockwise transfer type, cf.
                // thread_group_tensor_slice_transfer_v6r1.hpp; name and several arguments
                // elided in this excerpt)
                auto c_shuffle_block_copy_lds_to_global =
                    /* ... */<
                        /* ... */
                        EElementwiseOperation,
                        CGlobalMemoryDataOperation,
                        Sequence</* ... */,
                                 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                 /* ... */,
                                 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
                        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                        /* ... */
                        decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                        decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
                        /* ... */
                        CShuffleBlockTransferScalarPerVector_NPerBlock,
                        /* ... */>
                        {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                         /* ... */,
                         e_grid_desc_mblock_mperblock_nblock_nperblock,
                         /* ... */};
 
                // space-filling curve over the accumulator (register) side; the curve type
                // and several of its arguments are elided in this excerpt
                constexpr auto sfc_c_vgpr =
                    /* ... */<Sequence<CShuffleMXdlPerWavePerShuffle,
                                       CShuffleNXdlPerWavePerShuffle,
                                       /* ... */>,
                              /* ... */>{};

                // matching space-filling curve over the E-grid (global memory) side
                constexpr auto sfc_c_global =
                    /* ... */<Sequence</* ... */,
                                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                       /* ... */,
                                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

                constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

                static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
 
                // ... (loop setup between the static_assert above and this point is elided
                //      in this excerpt)
                static_for<0, num_access, 1>{}([&](auto access_id) {
                    c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                                  c_thread_buf,
                                                  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                  c_shuffle_block_buf);

                    // ... (LDS synchronization between the write above and the read below
                    //      is elided in this excerpt)

                    c_shuffle_block_copy_lds_to_global.Run(
                        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                        c_shuffle_block_buf,
                        e_grid_desc_mblock_mperblock_nblock_nperblock,
                        c_grid_buf);

                    if constexpr(access_id < num_access - 1)
                    {
                        constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                        c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                            e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                    }
                });

                // ... (the remainder of the epilogue and the closing braces of Run are not
                //      part of this excerpt)
 