21 template <
typename Problem_, 
typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
 
   42     using Traits = 
typename Problem::Traits;
 
   60         if constexpr(Problem::kBlockPerCu != -1)
 
   61             return Problem::kBlockPerCu;
 
   69     static constexpr 
const char* 
name = 
"fused_moe_flatmm";
 
   74         return Policy::template GetSmemSize_A<Problem>();
 
   79         return Policy::template GetSmemSize<Problem>();
 
   85         constexpr 
auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
 
   86         const auto a_coord    = a_dist.calculate_index();
 
   93         constexpr 
auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
 
   94         const auto o_coord    = o_dist.calculate_index();
 
   98     template <
typename AWindow, 
typename GWindow, 
typename DWindow, 
typename OWindow>
 
  100                                    const GWindow& g_window_,
 
  101                                    const DWindow& d_window_,
 
  108         _Pragma(
"clang diagnostic push") _Pragma(
"clang diagnostic ignored \"-Wc++20-extensions\"");
 
  109         constexpr 
auto NEG1  = 
number<-1>{};
 
  118             Policy::template GetSmemSize_A<Problem>());
 
  120         auto g_view = g_window_.get_bottom_tensor_view();
 
  122         auto u_view = [&]() {
 
  129                 index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
 
  130                 index_t kr_0 = hidden_size / BlockShape::Block_Kr0;
 
  133                     g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
 
  136                 const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
 
  142                 const auto u_view_1_ =
 
  153             a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
 
  156                                     Policy::template MakeGlobalTileDistribution_G<Problem>(),
 
  160                                     Policy::template MakeGlobalTileDistribution_D<Problem>(),
 
  163             o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
 
  165         using g_thread_type = decltype(
load_tile(g_win));
 
  166         using d_thread_type = decltype(
load_tile(d_win));
 
  168         using WarpGemm0  = decltype(Policy::template GetWarpGemm0<Problem>());
 
  169         using WarpGemm1  = decltype(Policy::template GetWarpGemm1<Problem>());
 
  170         auto warp_gemm_0 = WarpGemm0{};
 
  171         auto warp_gemm_1 = WarpGemm1{};
 
  176                                  smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
 
  177                              Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
 
  182                                  smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
 
  183                              Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
 
  186         auto a_sld_win0 = [&]() {
 
  187             using WG                        = WarpGemm0;
 
  197                 a_outer_dstr_enc, 
typename WG::AWarpDstrEncoding{});
 
  199                 make_tensor_view<address_space_enum::lds>(
 
  200                     smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
 
  201                 Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
 
  207         auto a_sld_win1 = [&]() {
 
  208             using WG                        = WarpGemm0;
 
  218                 a_outer_dstr_enc, 
typename WG::AWarpDstrEncoding{});
 
  220                 make_tensor_view<address_space_enum::lds>(
 
  221                     smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
 
  222                 Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
 
  227         auto bridge_sst_win = [&]() {
 
  229                 make_tensor_view<address_space_enum::lds>(
 
  231                     Policy::template MakeBridgeLdsStoreDesc<Problem>()),
 
  232                 Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
 
  236         auto bridge_sld_win = [&]() {
 
  238                 make_tensor_view<address_space_enum::lds>(
 
  240                     Policy::template MakeBridgeLdsLoadDesc<Problem>()),
 
  241                 Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
 
  243                 Policy::template MakeYTileDistribution<Problem>());
 
  249         constexpr 
auto issues_a = 
number<a_win.get_num_of_access()>{};
 
  250         constexpr 
auto issues_g = 
number<g_win.get_num_of_access()>{};
 
  253         constexpr 
auto issues_gemm0 =
 
  254             number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
 
  255                    warp_gemm_0.get_num_of_access()>{};
 
  256         constexpr 
auto issues_gemm1 =
 
  257             number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1 *
 
  258                    warp_gemm_1.get_num_of_access()>{};
 
  262             (hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
 
  264             (hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;
 
  266         using a_thread_type = decltype(
load_tile(a_sld_win0));
 
  270                          auto& a_store_, 
auto i_access, PreNop = {}) {
 
  273         auto move_a = [&]() {
 
  276         auto sld_a = [&](
auto& a_, 
auto& win_, 
auto i_access) {
 
  285                     if constexpr(i_access.
value == 0)
 
  287                         g_win.bottom_tensor_view_ = g_view;
 
  289                     else if constexpr(i_access.
value == issues_g / 2)
 
  291                         g_win.bottom_tensor_view_ = u_view;
 
  296         auto move_g = [&]() {
 
  305         auto move_d = [&]() {
 
  315         auto acc_0  = Policy::template MakeCBlockTile_Gemm0<Problem>();
 
  317             [&](
auto) { 
return Policy::template MakeCBlockTile_Gemm1<Problem>(); }, 
number<2>{});
 
  321         (
auto& t_c, 
auto& t_a, 
auto& t_b, 
auto i_access, PostNop = {}) {
 
  324             constexpr 
auto repeat_sub = WarpGemm::get_num_of_access();
 
  325             constexpr 
auto repeat_m = BlockShape::Repeat_M0;
 
  327             constexpr 
auto repeat_k = BlockShape::Repeat_K0;
 
  329             constexpr 
auto i_sub = i_access % repeat_sub;
 
  330             constexpr 
auto i_k = (i_access / repeat_sub) % repeat_k;
 
  331             constexpr 
auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
 
  332             constexpr 
auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
 
  334             using AWarpTensor = 
typename WarpGemm::AWarpTensor;
 
  335             using BWarpTensor = 
typename WarpGemm::BWarpTensor;
 
  336             using CWarpTensor = 
typename WarpGemm::CWarpTensor;
 
  337             using AWarpDstr = 
typename WarpGemm::AWarpDstr;
 
  338             using BWarpDstr = 
typename WarpGemm::BWarpDstr;
 
  339             using CWarpDstr = 
typename WarpGemm::CWarpDstr;
 
  345             constexpr 
auto a_warp_y_lengths = 
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
 
  346             constexpr 
auto b_warp_y_lengths = 
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
 
  347             constexpr 
auto c_warp_y_lengths = 
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
 
  350             w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
 
  355             w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
 
  360             w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
 
  366             t_c.set_y_sliced_thread_data(
 
  369                         w_c.get_thread_buffer());
 
  375         (
auto& t_c, 
auto& t_a, 
auto& t_b, 
auto i_access, PostNop = {}) {
 
  378             constexpr 
auto repeat_sub = WarpGemm::get_num_of_access();
 
  379             constexpr 
auto repeat_m = BlockShape::Repeat_M0;
 
  381             constexpr 
auto repeat_k = BlockShape::Repeat_K0;
 
  383             constexpr 
auto i_sub = i_access % repeat_sub;
 
  384             constexpr 
auto i_k = (i_access / repeat_sub) % repeat_k;
 
  385             constexpr 
auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
 
  386             constexpr 
auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
 
  388             using AWarpTensor = 
typename WarpGemm::AWarpTensor;
 
  389             using BWarpTensor = 
typename WarpGemm::BWarpTensor;
 
  390             using CWarpTensor = 
typename WarpGemm::CWarpTensor;
 
  391             using AWarpDstr = 
typename WarpGemm::AWarpDstr;
 
  392             using BWarpDstr = 
typename WarpGemm::BWarpDstr;
 
  393             using CWarpDstr = 
typename WarpGemm::CWarpDstr;
 
  399             constexpr 
auto a_warp_y_lengths = 
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
 
  400             constexpr 
auto b_warp_y_lengths = 
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
 
  401             constexpr 
auto c_warp_y_lengths = 
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
 
  404             w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
 
  409             w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
 
  414             w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
 
  420             t_c.set_y_sliced_thread_data(
 
  423                         w_c.get_thread_buffer());
 
  426         _Pragma(
"clang diagnostic pop");
 
  435         auto pipeline_gemm0 = [&]() {
 
  436             constexpr 
index_t total_loops = issues_gemm0;
 
  437             constexpr 
auto sr             = Policy::template GetSequencer_0<Problem>();
 
  438             static_assert(sr.size() == total_loops);
 
  440             constexpr 
auto c_sld_a_0 = 
MAKE_SC();
 
  441             constexpr 
auto c_gld_a_0 = 
MAKE_SC();
 
  442             constexpr 
auto c_gld_b_0 = 
MAKE_SC();
 
  445                 gemm_0(acc_0, as[I0], gs[I0], i_issue);
 
  446                 constexpr 
index_t slot = sr.at(i_issue);
 
  448                 if constexpr(slot & 
SLD_A)
 
  449                     sld_a(as[I1], a_sld_win1, 
number<
NEXT_SCI(c_sld_a_0, i_issue)>{});
 
  450                 if constexpr(slot & 
GLD_A)
 
  452                 if constexpr(slot & 
GLD_B)
 
  457             block_sync_load_raw(issues_a + issues_g);
 
  460             constexpr 
auto c_sld_a_1 = 
MAKE_SC();
 
  461             constexpr 
auto c_gld_a_1 = 
MAKE_SC();
 
  462             constexpr 
auto c_gld_b_1 = 
MAKE_SC();
 
  466                 gemm_0(acc_0, as[I1], gs[I1], i_issue);
 
  467                 constexpr 
index_t slot = sr.at(i_issue);
 
  469                 if constexpr(slot & 
SLD_A)
 
  470                     sld_a(as[I0], a_sld_win0, 
number<
NEXT_SCI(c_sld_a_1, i_issue)>{});
 
  471                 if constexpr(slot & 
GLD_A)
 
  473                 if constexpr(slot & 
GLD_B)
 
  478             block_sync_load_raw(issues_a + issues_g);
 
  482         auto pipeline_gemm0_tail = [&]() {
 
  483             constexpr 
index_t total_loops = issues_gemm0;
 
  484             constexpr 
auto sr             = Policy::template GetSequencer_0<Problem>();
 
  485             static_assert(sr.size() == total_loops);
 
  487             constexpr 
auto c_gld_b_0 = 
MAKE_SC();
 
  491                 gemm_0(acc_0, as[I0], gs[I0], i_issue);
 
  492                 constexpr 
index_t slot = sr.at(i_issue);
 
  494                 if constexpr(slot & 
GLD_B)
 
  498             block_sync_load_raw(issues_g);
 
  499             sld_a(as[I1], a_sld_win1, NEG1);
 
  503                 constexpr 
auto last_nop = [&]() {
 
  504                     if constexpr(i_issue == (total_loops - 1))
 
  509                 gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop); 
 
  513         auto y = Policy::template MakeYBlockTile<Problem>();
 
  515         auto pipeline_bridge = [&]() {
 
  517             auto y_pre = cast_tile<YDataType>(acc_0);
 
  526         auto pipeline_gemm1 = [&]() {
 
  527             constexpr 
index_t total_loops = issues_gemm1;
 
  528             constexpr 
auto sr             = Policy::template GetSequencer_1<Problem>();
 
  529             static_assert(sr.size() == total_loops);
 
  531             constexpr 
auto c_gld_b_0 = 
MAKE_SC();
 
  532             constexpr 
auto c_gst_o_0 = 
MAKE_SC();
 
  533             constexpr 
auto c_gld_b_1 = 
MAKE_SC();
 
  534             constexpr 
auto c_gst_o_1 = 
MAKE_SC();
 
  538                 gemm_1(acc_1s[I1], y, ds[I1], i_issue);
 
  539                 constexpr 
index_t slot = sr.at(i_issue);
 
  540                 if constexpr(slot & 
GLD_B)
 
  543                 if constexpr(slot & 
GST_O)
 
  545                     auto out = cast_tile<ODataType>(acc_1s[I0]);
 
  554                 gemm_1(acc_1s[I0], y, ds[I0], i_issue);
 
  555                 constexpr 
index_t slot = sr.at(i_issue);
 
  556                 if constexpr(slot & 
GLD_B)
 
  559                 if constexpr(slot & 
GST_O)
 
  561                     auto out = cast_tile<ODataType>(acc_1s[I1]);
 
  568         auto pipeline_gemm1_head = [&]() {
 
  569             constexpr 
index_t total_loops = issues_gemm1;
 
  570             constexpr 
auto sr             = Policy::template GetSequencer_1<Problem>();
 
  571             static_assert(sr.size() == total_loops);
 
  573             constexpr 
auto c_gld_b_0 = 
MAKE_SC();
 
  577                 gemm_1(acc_1s[I0], y, ds[I0], i_issue);
 
  578                 constexpr 
index_t slot = sr.at(i_issue);
 
  579                 if constexpr(slot & 
GLD_B)
 
  584         auto pipeline_gemm1_tail = [&]() {
 
  585             constexpr 
index_t total_loops = issues_gemm1;
 
  586             constexpr 
auto sr             = Policy::template GetSequencer_1<Problem>();
 
  587             static_assert(sr.size() == total_loops);
 
  589             constexpr 
auto c_gst_o_0 = 
MAKE_SC();
 
  593                 gemm_1(acc_1s[I1], y, ds[I1], i_issue);
 
  595                 constexpr 
index_t slot = sr.at(i_issue);
 
  596                 if constexpr(slot & 
GST_O)
 
  598                     auto out = cast_tile<ODataType>(acc_1s[I0]);
 
  603                 auto out = cast_tile<ODataType>(acc_1s[I1]);
 
  604                 atomic_add_o(out, NEG1);
 
  610         gld_a(a_sst_win0, NEG1, TRUE);
 
  611         gld_g(gs[I0], NEG1, TRUE);
 
  617         gld_a(a_sst_win1, NEG1); 
 
  621         block_sync_load_raw(issues_a + issues_g);
 
  625         const index_t iters_0 = (num_blocks_k0 - 2) / 2;
 
  627         while(i_0++ < iters_0)
 
  631         pipeline_gemm0_tail();
 
  635         const index_t iters_1 = (num_blocks_n1 - 2) / 2;
 
  637         pipeline_gemm1_head();
 
  638         while(i_1++ < iters_1)
 
  642         pipeline_gemm1_tail();
 
#define CK_TILE_DEVICE
Definition: config.hpp:41
 
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
 
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
 
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
 
Definition: cluster_descriptor.hpp:13
 
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
 
int32_t index_t
Definition: integer.hpp:9
 
CK_TILE_DEVICE void lds_load_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:820
 
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
 
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
 
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1055
 
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:826
 
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
 
CK_TILE_DEVICE auto load_tile_raw(T &tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Loads a tile of data using inline assembly.
Definition: load_tile.hpp:81
 
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
 
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:993
 
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
 
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
 
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: load_tile.hpp:133
 
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
 
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:177
 
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
 
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
 
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1026
 
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: update_tile.hpp:68
 
#define NEXT_SCI(c_, static_i_)
Definition: static_counter.hpp:109
 
#define MAKE_SC()
Definition: static_counter.hpp:104
 
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:23
 
static constexpr index_t SLD_A
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:54
 
static constexpr index_t kAlignmentA
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:49
 
static constexpr bool PadHiddenSize
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:46
 
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize_A()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:72
 
typename Problem::DScaleDataType DScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:36
 
typename Problem::BlockShape BlockShape
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:27
 
static constexpr index_t kAlignmentO
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:52
 
typename Problem::IndexDataType IndexDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:39
 
static constexpr index_t kBlockPerCu
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:59
 
remove_cvref_t< Policy_ > Policy
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:25
 
static constexpr const char * name
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:69
 
static constexpr index_t GLD_B
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:56
 
static constexpr index_t GLD_A
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:55
 
remove_cvref_t< Problem_ > Problem
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:24
 
typename Problem::ADataType ADataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:29
 
static CK_TILE_HOST_DEVICE auto GetOCoord()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:91
 
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:77
 
typename Problem::GDataType GDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:30
 
static constexpr index_t kAlignmentG
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:50
 
typename Problem::DDataType DDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:31
 
static CK_TILE_HOST_DEVICE auto GetACoord()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:83
 
typename Problem::ODataType ODataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:33
 
typename Problem::GScaleDataType GScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:35
 
typename Problem::AccDataType AccDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:32
 
static constexpr index_t GST_O
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:57
 
typename Problem::TopkWeightDataType TopkWeightDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:38
 
static constexpr index_t kAlignmentD
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:51
 
typename Problem::YDataType YDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:40
 
static constexpr bool IsGateOnly
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:44
 
typename Problem::Traits Traits
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:42
 
typename Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:37
 
static constexpr bool PadIntermediateSize
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:47
 
typename Problem::AScaleDataType AScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:34
 
static constexpr bool UseSmoothQuant
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:45
 
CK_TILE_DEVICE auto operator()(const AWindow &a_window_, const GWindow &g_window_, const DWindow &d_window_, OWindow &o_window_, TopkWeightDataType, CK_TILE_LDS_ADDR void *smem, index_t hidden_size, index_t intermediate_size)
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:99
 
Definition: integral_constant.hpp:13
 
static constexpr value_type value
Definition: integral_constant.hpp:16
 
Definition: sequence.hpp:49
 
Definition: functional.hpp:43
 
Definition: tile_distribution_encoding.hpp:26
 
Definition: tuple.hpp:192