22 namespace tensor_operation {
 
  124 template <
typename ALayout,
 
  130           typename AccDataType,
 
  131           typename CShuffleDataType,
 
  132           typename AElementwiseOperation,
 
  133           typename BElementwiseOperation,
 
  134           typename CElementwiseOperation,
 
  146           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  147           typename ABlockTransferThreadClusterArrangeOrder,
 
  148           typename ABlockTransferSrcAccessOrder,
 
  149           index_t ABlockTransferSrcVectorDim,
 
  150           index_t ABlockTransferSrcScalarPerVector,
 
  151           index_t ABlockTransferDstScalarPerVector_AK1,
 
  152           bool ABlockLdsExtraM,
 
  153           typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  154           typename BBlockTransferThreadClusterArrangeOrder,
 
  155           typename BBlockTransferSrcAccessOrder,
 
  156           index_t BBlockTransferSrcVectorDim,
 
  157           index_t BBlockTransferSrcScalarPerVector,
 
  158           index_t BBlockTransferDstScalarPerVector_BK1,
 
  159           bool BBlockLdsExtraN,
 
  160           index_t CShuffleMRepeatPerShuffle,
 
  161           index_t CShuffleNRepeatPerShuffle,
 
  162           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  163           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
 
  166           typename ComputeTypeA                       = CDataType,
 
  167           typename ComputeTypeB                       = ComputeTypeA,
 
  168           bool PermuteA                               = 
false,
 
  169           bool PermuteB                               = 
false>
 
  176                                                         AElementwiseOperation,
 
  177                                                         BElementwiseOperation,
 
  178                                                         CElementwiseOperation>
 
  191         AElementwiseOperation,
 
  192         BElementwiseOperation,
 
  193         CElementwiseOperation,
 
  205         ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  206         ABlockTransferThreadClusterArrangeOrder,
 
  207         ABlockTransferSrcAccessOrder,
 
  208         ABlockTransferSrcVectorDim,
 
  209         ABlockTransferSrcScalarPerVector,
 
  210         ABlockTransferDstScalarPerVector_AK1,
 
  213         BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  214         BBlockTransferThreadClusterArrangeOrder,
 
  215         BBlockTransferSrcAccessOrder,
 
  216         BBlockTransferSrcVectorDim,
 
  217         BBlockTransferSrcScalarPerVector,
 
  218         BBlockTransferDstScalarPerVector_BK1,
 
  221         CShuffleMRepeatPerShuffle,
 
  222         CShuffleNRepeatPerShuffle,
 
  223         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  273                              const BDataType* p_b,
 
  282                              AElementwiseOperation a_element_op,
 
  283                              BElementwiseOperation b_element_op,
 
  284                              CElementwiseOperation cde_element_op)
 
  286         return Argument{std::array<const void*, 1>{p_a},
 
  287                         std::array<const void*, 1>{p_b},
 
  288                         std::array<const void*, 0>{}, 
 
  293                         std::array<index_t, 1>{StrideA},
 
  294                         std::array<index_t, 1>{StrideB},
 
  295                         std::array<index_t, 0>{}, 
 
  316                                                       AElementwiseOperation a_element_op,
 
  317                                                       BElementwiseOperation b_element_op,
 
  318                                                       CElementwiseOperation c_element_op)
 override 
  320         return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
 
  321                                           std::array<const void*, 1>{p_b},
 
  322                                           std::array<const void*, 0>{}, 
 
  323                                           static_cast<CDataType*
>(p_c),
 
  327                                           std::array<index_t, 1>{StrideA},
 
  328                                           std::array<index_t, 1>{StrideB},
 
  329                                           std::array<index_t, 0>{}, 
 
  340         return std::make_unique<Invoker>(
Invoker{});
 
  346         auto str = std::stringstream();
 
  348         std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
 
  352         std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
 
  360         str << 
"DeviceGemm_Wmma_CShuffleV3" 
  363             << std::string(ALayout::name)[0]
 
  364             << std::string(BLayout::name)[0]
 
  365             << std::string(CLayout::name)[0]
 
  370             << MPerBlock << 
"x" << NPerBlock << 
"x" << KPerBlock << 
", " 
  372             << MPerWmma << 
"x"<<NPerWmma << 
", " 
  374             << MRepeat << 
"x" << NRepeat << 
", " 
  376             << ABlockTransferSrcScalarPerVector << 
"x" << BBlockTransferSrcScalarPerVector << 
", " 
  377             << 
"BlkGemmPipelineScheduler: " 
  378             << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << 
", " 
  379             << 
"BlkGemmPipelineVersion: " 
  380             << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << 
", " 
  381             << 
"BlkGemmPipelinePrefetchStages: " 
  382             << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << 
", " 
#define REGISTER_EXTRA_PRINTING_METHODS
Definition: device_base.hpp:47
 
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
 
GemmSpecialization
Definition: gemm_specialization.hpp:11
 
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
int32_t index_t
Definition: ck.hpp:299
 
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:406
 
ck::GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1Value, BK1Value, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB, false >::KPack static constexpr index_t KPack
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:154
 
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:233
 
Definition: sequence.hpp:43
 
Definition: tuple.hpp:186
 
Definition: tuple.hpp:117
 
Definition: device_base.hpp:197
 
Helper structure responsible for kernel invocation.
Definition: device_gemm_wmma_cshuffle_v3_common.hpp:57
 
Definition: device_gemm_wmma_cshuffle_v3_common.hpp:43
 
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_wmma_cshuffle_v3_common.hpp:268
 
"Universal" GEMM operation with SplitK support.
Definition: device_gemm_wmma_cshuffle_v3.hpp:179
 
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation cde_element_op)
Definition: device_gemm_wmma_cshuffle_v3.hpp:272
 
std::string GetTypeString() const override
Definition: device_gemm_wmma_cshuffle_v3.hpp:344
 
typename DeviceGemmCommon::Invoker Invoker
Definition: device_gemm_wmma_cshuffle_v3.hpp:254
 
static auto MakeInvoker()
Definition: device_gemm_wmma_cshuffle_v3.hpp:303
 
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_wmma_cshuffle_v3.hpp:262
 
bool GetPermuteA() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:269
 
typename GridwiseGemm::Argument Argument
Definition: device_gemm_wmma_cshuffle_v3.hpp:232
 
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:338
 
bool GetPermuteB() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:270
 
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition: device_gemm_wmma_cshuffle_v3.hpp:230
 
index_t GetKPerBlock() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:267
 
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_gemm_wmma_cshuffle_v3.hpp:306
 
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_wmma_cshuffle_v3.hpp:256
 
Definition: device_gemm_v2.hpp:22