21 template <
typename Problem_,
typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
42 using Traits =
typename Problem::Traits;
60 if constexpr(Problem::kBlockPerCu != -1)
61 return Problem::kBlockPerCu;
69 static constexpr
const char*
name =
"fused_moe_flatmm";
74 return Policy::template GetSmemSize_A<Problem>();
79 return Policy::template GetSmemSize<Problem>();
85 constexpr
auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
86 const auto a_coord = a_dist.calculate_index();
93 constexpr
auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
94 const auto o_coord = o_dist.calculate_index();
98 template <
typename AWindow,
typename GWindow,
typename DWindow,
typename OWindow>
100 const GWindow& g_window_,
101 const DWindow& d_window_,
108 _Pragma(
"clang diagnostic push") _Pragma(
"clang diagnostic ignored \"-Wc++20-extensions\"");
109 constexpr
auto NEG1 =
number<-1>{};
118 Policy::template GetSmemSize_A<Problem>());
120 auto g_view = g_window_.get_bottom_tensor_view();
122 auto u_view = [&]() {
129 index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
130 index_t kr_0 = hidden_size / BlockShape::Block_Kr0;
133 g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
136 const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
142 const auto u_view_1_ =
153 a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
156 Policy::template MakeGlobalTileDistribution_G<Problem>(),
160 Policy::template MakeGlobalTileDistribution_D<Problem>(),
163 o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
165 using g_thread_type = decltype(
load_tile(g_win));
166 using d_thread_type = decltype(
load_tile(d_win));
168 using WarpGemm0 = decltype(Policy::template GetWarpGemm0<Problem>());
169 using WarpGemm1 = decltype(Policy::template GetWarpGemm1<Problem>());
170 auto warp_gemm_0 = WarpGemm0{};
171 auto warp_gemm_1 = WarpGemm1{};
176 smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
177 Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
182 smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
183 Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
186 auto a_sld_win0 = [&]() {
187 using WG = WarpGemm0;
197 a_outer_dstr_enc,
typename WG::AWarpDstrEncoding{});
199 make_tensor_view<address_space_enum::lds>(
200 smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
201 Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
207 auto a_sld_win1 = [&]() {
208 using WG = WarpGemm0;
218 a_outer_dstr_enc,
typename WG::AWarpDstrEncoding{});
220 make_tensor_view<address_space_enum::lds>(
221 smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
222 Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
227 auto bridge_sst_win = [&]() {
229 make_tensor_view<address_space_enum::lds>(
231 Policy::template MakeBridgeLdsStoreDesc<Problem>()),
232 Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
236 auto bridge_sld_win = [&]() {
238 make_tensor_view<address_space_enum::lds>(
240 Policy::template MakeBridgeLdsLoadDesc<Problem>()),
241 Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
243 Policy::template MakeYTileDistribution<Problem>());
249 constexpr
auto issues_a =
number<a_win.get_num_of_access()>{};
250 constexpr
auto issues_g =
number<g_win.get_num_of_access()>{};
253 constexpr
auto issues_gemm0 =
254 number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
255 warp_gemm_0.get_num_of_access()>{};
256 constexpr
auto issues_gemm1 =
257 number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1 *
258 warp_gemm_1.get_num_of_access()>{};
262 (hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
264 (hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;
266 using a_thread_type = decltype(
load_tile(a_sld_win0));
270 auto& a_store_,
auto i_access, PreNop = {}) {
273 auto move_a = [&]() {
276 auto sld_a = [&](
auto& a_,
auto& win_,
auto i_access) {
285 if constexpr(i_access.
value == 0)
287 g_win.bottom_tensor_view_ = g_view;
289 else if constexpr(i_access.
value == issues_g / 2)
291 g_win.bottom_tensor_view_ = u_view;
296 auto move_g = [&]() {
305 auto move_d = [&]() {
315 auto acc_0 = Policy::template MakeCBlockTile_Gemm0<Problem>();
317 [&](
auto) {
return Policy::template MakeCBlockTile_Gemm1<Problem>(); },
number<2>{});
321 (
auto& t_c,
auto& t_a,
auto& t_b,
auto i_access, PostNop = {}) {
324 constexpr
auto repeat_sub = WarpGemm::get_num_of_access();
325 constexpr
auto repeat_m = BlockShape::Repeat_M0;
327 constexpr
auto repeat_k = BlockShape::Repeat_K0;
329 constexpr
auto i_sub = i_access % repeat_sub;
330 constexpr
auto i_k = (i_access / repeat_sub) % repeat_k;
331 constexpr
auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
332 constexpr
auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
334 using AWarpTensor =
typename WarpGemm::AWarpTensor;
335 using BWarpTensor =
typename WarpGemm::BWarpTensor;
336 using CWarpTensor =
typename WarpGemm::CWarpTensor;
337 using AWarpDstr =
typename WarpGemm::AWarpDstr;
338 using BWarpDstr =
typename WarpGemm::BWarpDstr;
339 using CWarpDstr =
typename WarpGemm::CWarpDstr;
345 constexpr
auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
346 constexpr
auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
347 constexpr
auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
350 w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
355 w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
360 w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
366 t_c.set_y_sliced_thread_data(
369 w_c.get_thread_buffer());
375 (
auto& t_c,
auto& t_a,
auto& t_b,
auto i_access, PostNop = {}) {
378 constexpr
auto repeat_sub = WarpGemm::get_num_of_access();
379 constexpr
auto repeat_m = BlockShape::Repeat_M0;
381 constexpr
auto repeat_k = BlockShape::Repeat_K0;
383 constexpr
auto i_sub = i_access % repeat_sub;
384 constexpr
auto i_k = (i_access / repeat_sub) % repeat_k;
385 constexpr
auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
386 constexpr
auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
388 using AWarpTensor =
typename WarpGemm::AWarpTensor;
389 using BWarpTensor =
typename WarpGemm::BWarpTensor;
390 using CWarpTensor =
typename WarpGemm::CWarpTensor;
391 using AWarpDstr =
typename WarpGemm::AWarpDstr;
392 using BWarpDstr =
typename WarpGemm::BWarpDstr;
393 using CWarpDstr =
typename WarpGemm::CWarpDstr;
399 constexpr
auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
400 constexpr
auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
401 constexpr
auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
404 w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
409 w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
414 w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
420 t_c.set_y_sliced_thread_data(
423 w_c.get_thread_buffer());
426 _Pragma(
"clang diagnostic pop");
435 auto pipeline_gemm0 = [&]() {
436 constexpr
index_t total_loops = issues_gemm0;
437 constexpr
auto sr = Policy::template GetSequencer_0<Problem>();
438 static_assert(sr.size() == total_loops);
440 constexpr
auto c_sld_a_0 =
MAKE_SC();
441 constexpr
auto c_gld_a_0 =
MAKE_SC();
442 constexpr
auto c_gld_b_0 =
MAKE_SC();
445 gemm_0(acc_0, as[I0], gs[I0], i_issue);
446 constexpr
index_t slot = sr.at(i_issue);
448 if constexpr(slot &
SLD_A)
449 sld_a(as[I1], a_sld_win1,
number<
NEXT_SCI(c_sld_a_0, i_issue)>{});
450 if constexpr(slot &
GLD_A)
452 if constexpr(slot &
GLD_B)
457 block_sync_load_raw(issues_a + issues_g);
460 constexpr
auto c_sld_a_1 =
MAKE_SC();
461 constexpr
auto c_gld_a_1 =
MAKE_SC();
462 constexpr
auto c_gld_b_1 =
MAKE_SC();
466 gemm_0(acc_0, as[I1], gs[I1], i_issue);
467 constexpr
index_t slot = sr.at(i_issue);
469 if constexpr(slot &
SLD_A)
470 sld_a(as[I0], a_sld_win0,
number<
NEXT_SCI(c_sld_a_1, i_issue)>{});
471 if constexpr(slot &
GLD_A)
473 if constexpr(slot &
GLD_B)
478 block_sync_load_raw(issues_a + issues_g);
482 auto pipeline_gemm0_tail = [&]() {
483 constexpr
index_t total_loops = issues_gemm0;
484 constexpr
auto sr = Policy::template GetSequencer_0<Problem>();
485 static_assert(sr.size() == total_loops);
487 constexpr
auto c_gld_b_0 =
MAKE_SC();
491 gemm_0(acc_0, as[I0], gs[I0], i_issue);
492 constexpr
index_t slot = sr.at(i_issue);
494 if constexpr(slot &
GLD_B)
498 block_sync_load_raw(issues_g);
499 sld_a(as[I1], a_sld_win1, NEG1);
503 constexpr
auto last_nop = [&]() {
504 if constexpr(i_issue == (total_loops - 1))
509 gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop);
513 auto y = Policy::template MakeYBlockTile<Problem>();
515 auto pipeline_bridge = [&]() {
517 auto y_pre = cast_tile<YDataType>(acc_0);
526 auto pipeline_gemm1 = [&]() {
527 constexpr
index_t total_loops = issues_gemm1;
528 constexpr
auto sr = Policy::template GetSequencer_1<Problem>();
529 static_assert(sr.size() == total_loops);
531 constexpr
auto c_gld_b_0 =
MAKE_SC();
532 constexpr
auto c_gst_o_0 =
MAKE_SC();
533 constexpr
auto c_gld_b_1 =
MAKE_SC();
534 constexpr
auto c_gst_o_1 =
MAKE_SC();
538 gemm_1(acc_1s[I1], y, ds[I1], i_issue);
539 constexpr
index_t slot = sr.at(i_issue);
540 if constexpr(slot &
GLD_B)
543 if constexpr(slot &
GST_O)
545 auto out = cast_tile<ODataType>(acc_1s[I0]);
554 gemm_1(acc_1s[I0], y, ds[I0], i_issue);
555 constexpr
index_t slot = sr.at(i_issue);
556 if constexpr(slot &
GLD_B)
559 if constexpr(slot &
GST_O)
561 auto out = cast_tile<ODataType>(acc_1s[I1]);
568 auto pipeline_gemm1_head = [&]() {
569 constexpr
index_t total_loops = issues_gemm1;
570 constexpr
auto sr = Policy::template GetSequencer_1<Problem>();
571 static_assert(sr.size() == total_loops);
573 constexpr
auto c_gld_b_0 =
MAKE_SC();
577 gemm_1(acc_1s[I0], y, ds[I0], i_issue);
578 constexpr
index_t slot = sr.at(i_issue);
579 if constexpr(slot &
GLD_B)
584 auto pipeline_gemm1_tail = [&]() {
585 constexpr
index_t total_loops = issues_gemm1;
586 constexpr
auto sr = Policy::template GetSequencer_1<Problem>();
587 static_assert(sr.size() == total_loops);
589 constexpr
auto c_gst_o_0 =
MAKE_SC();
593 gemm_1(acc_1s[I1], y, ds[I1], i_issue);
595 constexpr
index_t slot = sr.at(i_issue);
596 if constexpr(slot &
GST_O)
598 auto out = cast_tile<ODataType>(acc_1s[I0]);
603 auto out = cast_tile<ODataType>(acc_1s[I1]);
604 atomic_add_o(out, NEG1);
610 gld_a(a_sst_win0, NEG1, TRUE);
611 gld_g(gs[I0], NEG1, TRUE);
617 gld_a(a_sst_win1, NEG1);
621 block_sync_load_raw(issues_a + issues_g);
625 const index_t iters_0 = (num_blocks_k0 - 2) / 2;
627 while(i_0++ < iters_0)
631 pipeline_gemm0_tail();
635 const index_t iters_1 = (num_blocks_n1 - 2) / 2;
637 pipeline_gemm1_head();
638 while(i_1++ < iters_1)
642 pipeline_gemm1_tail();
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
int32_t index_t
Definition: integer.hpp:9
CK_TILE_DEVICE void lds_load_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:757
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1052
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:823
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE auto load_tile_raw(T &tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Loads a tile of data using inline assembly.
Definition: load_tile.hpp:58
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:993
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: load_tile.hpp:110
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:177
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1023
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: update_tile.hpp:68
#define NEXT_SCI(c_, static_i_)
Definition: static_counter.hpp:109
#define MAKE_SC()
Definition: static_counter.hpp:104
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:23
static constexpr index_t SLD_A
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:54
static constexpr index_t kAlignmentA
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:49
static constexpr bool PadHiddenSize
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:46
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize_A()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:72
typename Problem::DScaleDataType DScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:36
typename Problem::BlockShape BlockShape
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:27
static constexpr index_t kAlignmentO
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:52
typename Problem::IndexDataType IndexDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:39
static constexpr index_t kBlockPerCu
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:59
remove_cvref_t< Policy_ > Policy
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:25
static constexpr const char * name
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:69
static constexpr index_t GLD_B
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:56
static constexpr index_t GLD_A
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:55
remove_cvref_t< Problem_ > Problem
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:24
typename Problem::ADataType ADataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:29
static CK_TILE_HOST_DEVICE auto GetOCoord()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:91
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:77
typename Problem::GDataType GDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:30
static constexpr index_t kAlignmentG
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:50
typename Problem::DDataType DDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:31
static CK_TILE_HOST_DEVICE auto GetACoord()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:83
typename Problem::ODataType ODataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:33
typename Problem::GScaleDataType GScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:35
typename Problem::AccDataType AccDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:32
static constexpr index_t GST_O
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:57
typename Problem::TopkWeightDataType TopkWeightDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:38
static constexpr index_t kAlignmentD
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:51
typename Problem::YDataType YDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:40
static constexpr bool IsGateOnly
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:44
typename Problem::Traits Traits
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:42
typename Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:37
static constexpr bool PadIntermediateSize
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:47
typename Problem::AScaleDataType AScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:34
static constexpr bool UseSmoothQuant
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:45
CK_TILE_DEVICE auto operator()(const AWindow &a_window_, const GWindow &g_window_, const DWindow &d_window_, OWindow &o_window_, TopkWeightDataType, CK_TILE_LDS_ADDR void *smem, index_t hidden_size, index_t intermediate_size)
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:99
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192