21 template <
typename Problem_,
typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
42 using Traits =
typename Problem::Traits;
60 if constexpr(Problem::kBlockPerCu != -1)
61 return Problem::kBlockPerCu;
69 static constexpr
const char*
name =
"fused_moe_flatmm";
74 return Policy::template GetSmemSize_A<Problem>();
79 return Policy::template GetSmemSize<Problem>();
85 constexpr
auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
86 const auto a_coord = a_dist.calculate_index();
93 constexpr
auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
94 const auto o_coord = o_dist.calculate_index();
98 template <
typename AWindow,
typename GWindow,
typename DWindow,
typename OWindow>
100 const GWindow& g_window_,
101 const DWindow& d_window_,
108 _Pragma(
"clang diagnostic push") _Pragma(
"clang diagnostic ignored \"-Wc++20-extensions\"");
109 constexpr
auto NEG1 =
number<-1>{};
118 Policy::template GetSmemSize_A<Problem>());
120 auto g_view = g_window_.get_bottom_tensor_view();
122 auto u_view = [&]() {
129 index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
130 index_t kr_0 = hidden_size / BlockShape::Block_Kr0;
133 g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
136 const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
142 const auto u_view_1_ =
153 a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
156 Policy::template MakeGlobalTileDistribution_G<Problem>(),
160 Policy::template MakeGlobalTileDistribution_D<Problem>(),
163 o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
165 using g_thread_type = decltype(
load_tile(g_win));
166 using d_thread_type = decltype(
load_tile(d_win));
168 using WarpGemm0 = decltype(Policy::template GetWarpGemm0<Problem>());
169 using WarpGemm1 = decltype(Policy::template GetWarpGemm1<Problem>());
170 auto warp_gemm_0 = WarpGemm0{};
171 auto warp_gemm_1 = WarpGemm1{};
176 smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
177 Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
182 smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
183 Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
186 auto a_sld_win0 = [&]() {
187 using WG = WarpGemm0;
197 a_outer_dstr_enc,
typename WG::AWarpDstrEncoding{});
199 make_tensor_view<address_space_enum::lds>(
200 smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
201 Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
207 auto a_sld_win1 = [&]() {
208 using WG = WarpGemm0;
218 a_outer_dstr_enc,
typename WG::AWarpDstrEncoding{});
220 make_tensor_view<address_space_enum::lds>(
221 smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
222 Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
227 auto bridge_sst_win = [&]() {
229 make_tensor_view<address_space_enum::lds>(
231 Policy::template MakeBridgeLdsStoreDesc<Problem>()),
232 Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
236 auto bridge_sld_win = [&]() {
238 make_tensor_view<address_space_enum::lds>(
240 Policy::template MakeBridgeLdsLoadDesc<Problem>()),
241 Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
243 Policy::template MakeYTileDistribution<Problem>());
249 constexpr
auto issues_a =
number<a_win.get_num_of_access()>{};
250 constexpr
auto issues_g =
number<g_win.get_num_of_access()>{};
253 constexpr
auto issues_gemm0 =
254 number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
255 warp_gemm_0.get_num_of_access()>{};
256 constexpr
auto issues_gemm1 =
257 number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1 *
258 warp_gemm_1.get_num_of_access()>{};
262 (hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
264 (hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;
266 using a_thread_type = decltype(
load_tile(a_sld_win0));
270 auto& a_store_,
auto i_access, PreNop = {})
274 auto move_a = [&]() {
277 auto sld_a = [&](
auto& a_,
auto& win_,
auto i_access) {
282 auto& g_,
auto i_access, PreNop = {})
287 if constexpr(i_access.
value == 0)
289 g_win.bottom_tensor_view_ = g_view;
291 else if constexpr(i_access.
value == issues_g / 2)
293 g_win.bottom_tensor_view_ = u_view;
298 auto move_g = [&]() {
304 auto& d_,
auto i_access, PreNop = {})
308 auto move_d = [&]() {
314 auto& o_,
auto i_access, PreNop = {})
319 auto acc_0 = Policy::template MakeCBlockTile_Gemm0<Problem>();
321 [&](
auto) {
return Policy::template MakeCBlockTile_Gemm1<Problem>(); },
number<2>{});
325 (
auto& t_c,
auto& t_a,
auto& t_b,
auto i_access, PostNop = {}) {
328 constexpr
auto repeat_sub = WarpGemm::get_num_of_access();
329 constexpr
auto repeat_m = BlockShape::Repeat_M0;
331 constexpr
auto repeat_k = BlockShape::Repeat_K0;
333 constexpr
auto i_sub = i_access % repeat_sub;
334 constexpr
auto i_k = (i_access / repeat_sub) % repeat_k;
335 constexpr
auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
336 constexpr
auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
338 using AWarpTensor =
typename WarpGemm::AWarpTensor;
339 using BWarpTensor =
typename WarpGemm::BWarpTensor;
340 using CWarpTensor =
typename WarpGemm::CWarpTensor;
341 using AWarpDstr =
typename WarpGemm::AWarpDstr;
342 using BWarpDstr =
typename WarpGemm::BWarpDstr;
343 using CWarpDstr =
typename WarpGemm::CWarpDstr;
349 constexpr
auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
350 constexpr
auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
351 constexpr
auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
354 w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
359 w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
364 w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
370 t_c.set_y_sliced_thread_data(
373 w_c.get_thread_buffer());
379 (
auto& t_c,
auto& t_a,
auto& t_b,
auto i_access, PostNop = {}) {
382 constexpr
auto repeat_sub = WarpGemm::get_num_of_access();
383 constexpr
auto repeat_m = BlockShape::Repeat_M0;
385 constexpr
auto repeat_k = BlockShape::Repeat_K0;
387 constexpr
auto i_sub = i_access % repeat_sub;
388 constexpr
auto i_k = (i_access / repeat_sub) % repeat_k;
389 constexpr
auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
390 constexpr
auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
392 using AWarpTensor =
typename WarpGemm::AWarpTensor;
393 using BWarpTensor =
typename WarpGemm::BWarpTensor;
394 using CWarpTensor =
typename WarpGemm::CWarpTensor;
395 using AWarpDstr =
typename WarpGemm::AWarpDstr;
396 using BWarpDstr =
typename WarpGemm::BWarpDstr;
397 using CWarpDstr =
typename WarpGemm::CWarpDstr;
403 constexpr
auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
404 constexpr
auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
405 constexpr
auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
408 w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
413 w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
418 w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
424 t_c.set_y_sliced_thread_data(
427 w_c.get_thread_buffer());
430 _Pragma(
"clang diagnostic pop");
439 auto pipeline_gemm0 = [&]() {
440 constexpr
index_t total_loops = issues_gemm0;
441 constexpr
auto sr = Policy::template GetSequencer_0<Problem>();
442 static_assert(sr.size() == total_loops);
444 constexpr
auto c_sld_a_0 =
MAKE_SC();
445 constexpr
auto c_gld_a_0 =
MAKE_SC();
446 constexpr
auto c_gld_b_0 =
MAKE_SC();
449 gemm_0(acc_0, as[I0], gs[I0], i_issue);
450 constexpr
index_t slot = sr.at(i_issue);
452 if constexpr(slot &
SLD_A)
453 sld_a(as[I1], a_sld_win1,
number<
NEXT_SCI(c_sld_a_0, i_issue)>{});
454 if constexpr(slot &
GLD_A)
456 if constexpr(slot &
GLD_B)
464 constexpr
auto c_sld_a_1 =
MAKE_SC();
465 constexpr
auto c_gld_a_1 =
MAKE_SC();
466 constexpr
auto c_gld_b_1 =
MAKE_SC();
470 gemm_0(acc_0, as[I1], gs[I1], i_issue);
471 constexpr
index_t slot = sr.at(i_issue);
473 if constexpr(slot &
SLD_A)
474 sld_a(as[I0], a_sld_win0,
number<
NEXT_SCI(c_sld_a_1, i_issue)>{});
475 if constexpr(slot &
GLD_A)
477 if constexpr(slot &
GLD_B)
486 auto pipeline_gemm0_tail = [&]() {
487 constexpr
index_t total_loops = issues_gemm0;
488 constexpr
auto sr = Policy::template GetSequencer_0<Problem>();
489 static_assert(sr.size() == total_loops);
491 constexpr
auto c_gld_b_0 =
MAKE_SC();
495 gemm_0(acc_0, as[I0], gs[I0], i_issue);
496 constexpr
index_t slot = sr.at(i_issue);
498 if constexpr(slot &
GLD_B)
503 sld_a(as[I1], a_sld_win1, NEG1);
507 constexpr
auto last_nop = [&]() {
508 if constexpr(i_issue == (total_loops - 1))
513 gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop);
517 auto y = Policy::template MakeYBlockTile<Problem>();
519 auto pipeline_bridge = [&]() {
521 auto y_pre = cast_tile<YDataType>(acc_0);
530 auto pipeline_gemm1 = [&]() {
531 constexpr
index_t total_loops = issues_gemm1;
532 constexpr
auto sr = Policy::template GetSequencer_1<Problem>();
533 static_assert(sr.size() == total_loops);
535 constexpr
auto c_gld_b_0 =
MAKE_SC();
536 constexpr
auto c_gst_o_0 =
MAKE_SC();
537 constexpr
auto c_gld_b_1 =
MAKE_SC();
538 constexpr
auto c_gst_o_1 =
MAKE_SC();
542 gemm_1(acc_1s[I1], y, ds[I1], i_issue);
543 constexpr
index_t slot = sr.at(i_issue);
544 if constexpr(slot &
GLD_B)
547 if constexpr(slot &
GST_O)
549 auto out = cast_tile<ODataType>(acc_1s[I0]);
558 gemm_1(acc_1s[I0], y, ds[I0], i_issue);
559 constexpr
index_t slot = sr.at(i_issue);
560 if constexpr(slot &
GLD_B)
563 if constexpr(slot &
GST_O)
565 auto out = cast_tile<ODataType>(acc_1s[I1]);
572 auto pipeline_gemm1_head = [&]() {
573 constexpr
index_t total_loops = issues_gemm1;
574 constexpr
auto sr = Policy::template GetSequencer_1<Problem>();
575 static_assert(sr.size() == total_loops);
577 constexpr
auto c_gld_b_0 =
MAKE_SC();
581 gemm_1(acc_1s[I0], y, ds[I0], i_issue);
582 constexpr
index_t slot = sr.at(i_issue);
583 if constexpr(slot &
GLD_B)
588 auto pipeline_gemm1_tail = [&]() {
589 constexpr
index_t total_loops = issues_gemm1;
590 constexpr
auto sr = Policy::template GetSequencer_1<Problem>();
591 static_assert(sr.size() == total_loops);
593 constexpr
auto c_gst_o_0 =
MAKE_SC();
597 gemm_1(acc_1s[I1], y, ds[I1], i_issue);
599 constexpr
index_t slot = sr.at(i_issue);
600 if constexpr(slot &
GST_O)
602 auto out = cast_tile<ODataType>(acc_1s[I0]);
607 auto out = cast_tile<ODataType>(acc_1s[I1]);
608 atomic_add_o(out, NEG1);
614 gld_a(a_sst_win0, NEG1, TRUE);
615 gld_g(gs[I0], NEG1, TRUE);
621 gld_a(a_sst_win1, NEG1);
629 const index_t iters_0 = (num_blocks_k0 - 2) / 2;
631 while(i_0++ < iters_0)
635 pipeline_gemm0_tail();
639 const index_t iters_1 = (num_blocks_n1 - 2) / 2;
641 pipeline_gemm1_head();
642 while(i_1++ < iters_1)
646 pipeline_gemm1_tail();
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_LDS_ADDR
Definition: config.hpp:56
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:420
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: load_tile.hpp:149
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
int32_t index_t
Definition: integer.hpp:9
CK_TILE_DEVICE void lds_load_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:624
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:27
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1046
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:817
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
CK_TILE_DEVICE auto load_tile_raw(T &tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Loads a tile of data using inline assembly.
Definition: load_tile.hpp:106
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:1124
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:400
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:145
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt=0)
Definition: arch.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1017
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: update_tile.hpp:68
#define NEXT_SCI(c_, static_i_)
Definition: static_counter.hpp:109
#define MAKE_SC()
Definition: static_counter.hpp:104
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:23
static constexpr index_t SLD_A
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:54
static constexpr index_t kAlignmentA
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:49
static constexpr bool PadHiddenSize
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:46
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize_A()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:72
typename Problem::DScaleDataType DScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:36
typename Problem::BlockShape BlockShape
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:27
static constexpr index_t kAlignmentO
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:52
typename Problem::IndexDataType IndexDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:39
static constexpr index_t kBlockPerCu
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:59
remove_cvref_t< Policy_ > Policy
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:25
static constexpr const char * name
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:69
static constexpr index_t GLD_B
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:56
static constexpr index_t GLD_A
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:55
remove_cvref_t< Problem_ > Problem
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:24
typename Problem::ADataType ADataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:29
static CK_TILE_HOST_DEVICE auto GetOCoord()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:91
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:77
typename Problem::GDataType GDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:30
static constexpr index_t kAlignmentG
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:50
typename Problem::DDataType DDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:31
static CK_TILE_HOST_DEVICE auto GetACoord()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:83
typename Problem::ODataType ODataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:33
typename Problem::GScaleDataType GScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:35
typename Problem::AccDataType AccDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:32
static constexpr index_t GST_O
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:57
typename Problem::TopkWeightDataType TopkWeightDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:38
static constexpr index_t kAlignmentD
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:51
typename Problem::YDataType YDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:40
static constexpr bool IsGateOnly
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:44
typename Problem::Traits Traits
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:42
typename Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:37
static constexpr bool PadIntermediateSize
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:47
typename Problem::AScaleDataType AScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:34
static constexpr bool UseSmoothQuant
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:45
CK_TILE_DEVICE auto operator()(const AWindow &a_window_, const GWindow &g_window_, const DWindow &d_window_, OWindow &o_window_, TopkWeightDataType, CK_TILE_LDS_ADDR void *smem, index_t hidden_size, index_t intermediate_size)
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:99
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: sequence.hpp:52
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192