8 #include <hip/hip_runtime.h> 
   22 namespace tensor_operation {
 
   25 template <
typename ALayout,
 
   30           typename AScaleDataType,
 
   32           typename BScaleDataType,
 
   35           typename GemmAccDataType,
 
   36           typename CShuffleDataType,
 
   37           typename AElementwiseOperation,
 
   38           typename BElementwiseOperation,
 
   39           typename CElementwiseOperation,
 
   54           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
   55           typename ABlockTransferThreadClusterArrangeOrder,
 
   56           typename ABlockTransferSrcAccessOrder,
 
   57           index_t ABlockTransferSrcVectorDim,
 
   58           index_t ABlockTransferSrcScalarPerVector,
 
   59           index_t ABlockTransferDstScalarPerVector_AK1,
 
   61           typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
   62           typename BBlockTransferThreadClusterArrangeOrder,
 
   63           typename BBlockTransferSrcAccessOrder,
 
   64           index_t BBlockTransferSrcVectorDim,
 
   65           index_t BBlockTransferSrcScalarPerVector,
 
   66           index_t BBlockTransferDstScalarPerVector_BK1,
 
   68           index_t CShuffleMXdlPerWavePerShuffle,
 
   69           index_t CShuffleNXdlPerWavePerShuffle,
 
   70           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
   71           typename CDEShuffleBlockTransferScalarPerVectors,
 
   75           bool NSwizzle                               = 
false,
 
   76           bool IsInputGemm                            = 
true,
 
   77           bool MulRoutedWeight                        = 
false,
 
   79           typename ComputeTypeA                       = CDataType,
 
   80           typename ComputeTypeB                       = ComputeTypeA,
 
   81           typename LDSTypeA                           = ComputeTypeA,
 
   82           typename LDSTypeB                           = ComputeTypeB>
 
   97                                                         AElementwiseOperation,
 
   98                                                         BElementwiseOperation,
 
   99                                                         CElementwiseOperation>
 
  105     template <index_t NXdlPerWave_>
 
  117         AElementwiseOperation,
 
  118         BElementwiseOperation,
 
  119         CElementwiseOperation,
 
  134         ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  135         ABlockTransferThreadClusterArrangeOrder,
 
  136         ABlockTransferSrcAccessOrder,
 
  137         ABlockTransferSrcVectorDim,
 
  138         ABlockTransferSrcScalarPerVector,
 
  139         ABlockTransferDstScalarPerVector_AK1,
 
  142         BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  143         BBlockTransferThreadClusterArrangeOrder,
 
  144         BBlockTransferSrcAccessOrder,
 
  145         BBlockTransferSrcVectorDim,
 
  146         BBlockTransferSrcScalarPerVector,
 
  147         BBlockTransferDstScalarPerVector_BK1,
 
  150         CShuffleMXdlPerWavePerShuffle,
 
  151         math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_),
 
  152         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  153         CDEShuffleBlockTransferScalarPerVectors,
 
  189         template <
typename Gr
idwiseGemm>
 
  190         float RunImp(
const typename GridwiseGemm::Argument& arg,
 
  193             if(stream_config.log_level_ > 0)
 
  198             if(!GridwiseGemm::CheckValidity(arg))
 
  200                 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
 
  204             std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
 
  208             index_t k_grain = arg.KBatch * KPerBlock;
 
  209             index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
 
  211             const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
 
  212             const auto RunKernel             = [&](
const auto& kernel) {
 
  213                 if(stream_config.flush_cache)
 
  216                     std::array<std::size_t, NumDTensor> DsSize;
 
  220                     const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
 
  221                         arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
 
  222                     const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
 
  223                         arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
 
  225                     auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
 
  227                     auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
 
  230                     const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
 
  231                         arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
 
  235                         DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * 
sizeof(DDataType);
 
  240                                      stream_config.rotating_count,
 
  244                     rotating_mem.Print();
 
  246                     auto run_flush_cache = [&]() {
 
  253                             hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
 
  255                                                              arg_.M * arg_.N * 
sizeof(CDataType),
 
  256                                                              stream_config.stream_id_));
 
  259                     ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
 
  271                         hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
 
  273                                                          arg.M * arg.N * 
sizeof(CDataType),
 
  274                                                          stream_config.stream_id_));
 
  277                         stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
 
  281             constexpr 
auto estimated_reg_a = MPerBlock * KPerBlock * 
sizeof(ADataType) / BlockSize /
 
  282                                              4 * (1 + GridwiseGemm::NWave);
 
  283             constexpr 
auto estimated_reg_b = NPerBlock * KPerBlock * 
sizeof(BDataType) / BlockSize /
 
  284                                              4 * (2) * (IsInputGemm ? 2 : 1);
 
  285             constexpr 
auto estimated_reg_c = MPerBlock * NPerBlock * 
sizeof(GemmAccDataType) /
 
  286                                              BlockSize / 4 * (IsInputGemm ? 2 : 1);
 
  287             constexpr 
auto estimated_reg_total =
 
  288                 estimated_reg_a + estimated_reg_b + estimated_reg_c;
 
  290             constexpr 
index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
 
  292             constexpr 
auto MemoryDataOp =
 
  295             if(has_main_k_block_loop)
 
  301                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Odd)
 
  324                     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Odd)
 
  345                     throw std::runtime_error(
"todo: only v1 & v2 support now");
 
  354                     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Odd)
 
  376                     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Odd)
 
  407             return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
 
  424         if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
 
  440         if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
 
  470                              const void* p_sorted_expert_ids,
 
  471                              const void* p_max_token_id,
 
  474                              std::array<const void*, NumDTensor> p_ds,
 
  483                              std::array<index_t, NumDTensor> StrideDs,
 
  485                              const void* p_a_scale,
 
  486                              const void* p_b_scale,
 
  488                              AElementwiseOperation a_element_op,
 
  489                              BElementwiseOperation b_element_op,
 
  490                              CElementwiseOperation c_element_op)
 
  493                         static_cast<const index_t*
>(p_sorted_expert_ids),
 
  494                         static_cast<const index_t*
>(p_max_token_id),
 
  495                         static_cast<const ADataType*
>(p_a),
 
  496                         static_cast<const BDataType*
>(p_b),
 
  498                         static_cast<CDataType*
>(p_c),
 
  508                         static_cast<const AScaleDataType*
>(p_a_scale),
 
  509                         static_cast<const BScaleDataType*
>(p_b_scale),
 
  521                                                       std::array<const void*, NumDTensor> p_ds,
 
  528                                                       std::array<ck::index_t, NumDTensor> StrideDs,
 
  530                                                       const void* p_a_scale,
 
  531                                                       const void* p_b_scale,
 
  533                                                       AElementwiseOperation a_element_op,
 
  534                                                       BElementwiseOperation b_element_op,
 
  535                                                       CElementwiseOperation c_element_op)
 override 
  537         return std::make_unique<Argument>(
nullptr,
 
  540                                           static_cast<const ADataType*
>(p_a),
 
  541                                           static_cast<const BDataType*
>(p_b),
 
  543                                           static_cast<CDataType*
>(p_c),
 
  553                                           static_cast<const AScaleDataType*
>(p_a_scale),
 
  554                                           static_cast<const BScaleDataType*
>(p_b_scale),
 
  564         return std::make_unique<Invoker>(
Invoker{});
 
  570         auto str = std::stringstream();
 
  572         std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
 
  576         std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
 
  582         str << 
"DeviceMoeGEmm" 
  585             << std::string(ALayout::name)[0]
 
  586             << std::string(BLayout::name)[0]
 
  587             << std::string(CLayout::name)[0]
 
  592             << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock << 
", " 
  594             << MPerXDL<<
"x"<<NPerXDL << 
", " 
  596             << MXdlPerWave<<
"x" << NXdlPerWave<<
", " 
  598             << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", " 
  599             << 
"BlkGemmPipelineScheduler: " 
  600             << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << 
", " 
  601             << 
"BlkGemmPipelineVersion: " 
  602             << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << 
", " 
  603             << 
"BlkGemmPipelinePrefetchStages: " 
  604             << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
 
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:114
 
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
 
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
 
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
 
GemmSpecialization
Definition: gemm_specialization.hpp:11
 
void flush_icache()
Definition: flush_cache.hpp:383
 
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
 
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
 
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:46
 
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
 
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
 
constexpr bool is_same_v
Definition: type.hpp:283
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
int32_t index_t
Definition: ck.hpp:299
 
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:84
 
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:108
 
Definition: stream_config.hpp:10
 
Definition: gridwise_moe_gemm_blockscale.hpp:666
 
Definition: gridwise_moe_gemm_blockscale.hpp:177
 
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_gemm_blockscale.hpp:968
 
Definition: data_type.hpp:187
 
Definition: functional2.hpp:33
 
Definition: device_base.hpp:197
 
Definition: device_base.hpp:208
 
Definition: device_gemm_multiple_d_ab_scale.hpp:82
 
Definition: device_moe_gemm_blockscale.hpp:188
 
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_moe_gemm_blockscale.hpp:190
 
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_moe_gemm_blockscale.hpp:404
 
Definition: device_moe_gemm_blockscale.hpp:100
 
static constexpr index_t BPackedSize
Definition: device_moe_gemm_blockscale.hpp:177
 
static constexpr auto NXdlPerWave32
Definition: device_moe_gemm_blockscale.hpp:103
 
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_moe_gemm_blockscale.hpp:562
 
static bool IsSupportedArgument(const Argument &arg)
Definition: device_moe_gemm_blockscale.hpp:417
 
static constexpr index_t APackedSize
Definition: device_moe_gemm_blockscale.hpp:170
 
int GetPreShuffleParameters() override
Definition: device_moe_gemm_blockscale.hpp:184
 
std::string GetTypeString() const override
Definition: device_moe_gemm_blockscale.hpp:568
 
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_moe_gemm_blockscale.hpp:519
 
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_moe_gemm_blockscale.hpp:102
 
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_moe_gemm_blockscale.hpp:464
 
typename GridwiseGemm64::Argument Argument
Definition: device_moe_gemm_blockscale.hpp:168
 
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_moe_gemm_blockscale.hpp:469
 
static constexpr bool IsValidCompilationParameter()
Definition: device_moe_gemm_blockscale.hpp:411
 
static constexpr index_t NumDTensor
Definition: device_moe_gemm_blockscale.hpp:104
 
static auto MakeInvoker()
Definition: device_moe_gemm_blockscale.hpp:516
 
Definition: flush_cache.hpp:174