Composable Kernel: include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp — Source File

fused_moegemm_pipeline_flatmm_ex.hpp
(Generated documentation page; go to the documentation of this file for the annotated reference.)
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 /*
13 This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
14 we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
15 
16  <----- gemm-N ------>
17  +----+----+----+----+
18  | w0 | w1 | w2 | w3 | gemm-m
19  +----+----+----+----+
20 */
// NOTE(review): this block is a source listing recovered from a generated
// documentation page. The leading number on each line is the original file's
// line number as printed by the generator, and several original lines are
// missing from this extraction (e.g. the struct name on the declaration line,
// the SLD_A/GLD_A/GLD_B/GST_O issue-slot bit constants, the `as`/`gs`/`ds`
// register double-buffers, and some call arguments). Per the page's own index,
// the struct is declared at original line 23 of
// fused_moegemm_pipeline_flatmm_ex.hpp. Do not treat this text as compilable;
// recover the original header before editing code.
//
// Purpose (from the visible code): a fused-MoE flat-matmul GEMM pipeline that
// runs gemm-0 (token activations x gate/up weights), bridges the accumulator
// through LDS as YDataType, then gemm-1 (Y x down weights) whose results are
// accumulated into the output window via update_tile_raw.
21 template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
23 {
26 
// Tile-shape and element-type aliases forwarded from the Problem descriptor.
27  using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
28 
29  using ADataType = typename Problem::ADataType;
30  using GDataType = typename Problem::GDataType;
31  using DDataType = typename Problem::DDataType;
32  using AccDataType = typename Problem::AccDataType;
33  using ODataType = typename Problem::ODataType;
34  using AScaleDataType = typename Problem::AScaleDataType;
35  using GScaleDataType = typename Problem::GScaleDataType;
36  using DScaleDataType = typename Problem::DScaleDataType;
37  using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
38  using TopkWeightDataType = typename Problem::TopkWeightDataType;
39  using IndexDataType = typename Problem::IndexDataType;
40  using YDataType = typename Problem::YDataType;
41 
42  using Traits = typename Problem::Traits;
43 
// Compile-time feature switches taken from the problem traits.
44  static constexpr bool IsGateOnly = Traits::IsGateOnly;
45  static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
46  static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
47  static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
48 
// Per-operand vector alignments for global memory access, chosen by the policy.
49  static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
50  static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
51  static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
52  static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
53 
// NOTE(review): original lines 54-57 (the SLD_A/GLD_A/GLD_B/GST_O slot
// constants listed in the page index) are missing from this extraction.
58 
// Blocks-per-CU occupancy hint: honor an explicit Problem::kBlockPerCu,
// otherwise fall back to 2.
59  static constexpr index_t kBlockPerCu = []() {
60  if constexpr(Problem::kBlockPerCu != -1)
61  return Problem::kBlockPerCu;
62  else
63  {
64  // minimize occupancy
65  return 2;
66  }
67  }();
68 
69  static constexpr const char* name = "fused_moe_flatmm";
70 
71  // TODO: there are multiple buffers
// LDS byte count of one A-tile staging buffer (accessor signature on the
// missing original line 72; the index names it GetSmemSize_A).
73  {
74  return Policy::template GetSmemSize_A<Problem>();
75  }
76 
// Total LDS byte count required by the whole pipeline (accessor signature on
// the missing original line 77; the index names it GetSmemSize).
78  {
79  return Policy::template GetSmemSize<Problem>();
80  }
81 
82  // this is the thread-offset along row/col
// Per-thread coordinate of this lane within the A global tile distribution
// (signature on the missing original line 83; index names it GetACoord).
84  {
85  constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
86  const auto a_coord = a_dist.calculate_index();
87  return a_coord;
88  }
89 
90  // this is the thread-offset along row/col
// Per-thread coordinate within the O global tile distribution (signature on
// the missing original line 91; index names it GetOCoord).
92  {
93  constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
94  const auto o_coord = o_dist.calculate_index();
95  return o_coord;
96  }
97 
// Runs the fused two-stage pipeline for one workgroup tile:
//   stage 0 : acc_0 += A (staged through a double-buffered LDS ping-pong)
//             x G (pre-shuffled gate/up weights), K-blocked over hidden_size;
//   bridge  : acc_0 is cast to YDataType and round-tripped through LDS to
//             re-distribute it for stage 1;
//   stage 1 : acc_1s += Y x D (down weights), with results accumulated into
//             o_window_ via update_tile_raw (atomic-style update, oob check on).
// topk_weight is accepted but unused in this listing; hidden_size and
// intermediate_size are the runtime GEMM extents. `smem` must provide at
// least GetSmemSize() bytes.
98  template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
99  CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
100  const GWindow& g_window_,
101  const DWindow& d_window_,
102  OWindow& o_window_,
103  TopkWeightDataType /*topk_weight*/,
104  CK_TILE_LDS_ADDR void* smem,
105  index_t hidden_size,
106  index_t intermediate_size)
107  {
108  _Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"");
109  constexpr auto NEG1 = number<-1>{};
110  constexpr auto I0 = number<0>{};
111  constexpr auto I1 = number<1>{};
112  constexpr auto TRUE = bool_constant<true>{};
113  constexpr auto FALSE = bool_constant<false>{};
114 
// Two A-tile LDS buffers carved out of `smem` for ping-pong staging; buffer 1
// starts GetSmemSize_A() bytes after buffer 0.
115  CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
116  CK_TILE_LDS_ADDR ADataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(
117  reinterpret_cast<CK_TILE_LDS_ADDR char*>(smem) +
118  Policy::template GetSmemSize_A<Problem>());
119 
120  auto g_view = g_window_.get_bottom_tensor_view();
121 
// View of the "up" half of the fused gate+up weight tensor, built from a raw
// pointer offset of (nr_0 / 2) N-tiles past the gate data. When IsGateOnly,
// u_view simply aliases g_view. NOTE(review): original lines 138, 140 and
// 143-147 (part of the view strides/lengths and the pad_tensor_view
// arguments) are missing from this extraction.
122  auto u_view = [&]() {
123  if constexpr(IsGateOnly)
124  {
125  return g_view;
126  }
127  else
128  {
129  index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
130  index_t kr_0 = hidden_size / BlockShape::Block_Kr0;
131 
132  const GDataType* g_ptr =
133  g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
134  const GDataType* u_ptr = g_ptr + (nr_0 / 2) * kr_0 * number<BlockShape::Block_W0>{};
135 
136  const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
137  u_ptr,
139  make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
141  number<1>{});
142  const auto u_view_1_ =
143  pad_tensor_view(u_view_,
148  return u_view_1_;
149  }
150  }();
151 
// Linear tile windows for the four global operands, each with its
// policy-chosen thread distribution. NOTE(review): original lines 157 and
// 161 (trailing arguments of the g/d window constructions) are missing here.
152  auto a_win = make_tile_window_linear(
153  a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
154  auto g_win =
155  make_tile_window_linear(g_window_,
156  Policy::template MakeGlobalTileDistribution_G<Problem>(),
158  auto d_win =
159  make_tile_window_linear(d_window_,
160  Policy::template MakeGlobalTileDistribution_D<Problem>(),
162  auto o_win = make_tile_window_linear(
163  o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
164 
165  using g_thread_type = decltype(load_tile(g_win));
166  using d_thread_type = decltype(load_tile(d_win));
167 
168  using WarpGemm0 = decltype(Policy::template GetWarpGemm0<Problem>());
169  using WarpGemm1 = decltype(Policy::template GetWarpGemm1<Problem>());
170  auto warp_gemm_0 = WarpGemm0{};
171  auto warp_gemm_1 = WarpGemm1{};
172 
// LDS store windows for the two A staging buffers (async global->LDS target).
173  // issues_warps_lanes
174  auto a_sst_win0 =
175  make_tile_window(make_tensor_view<address_space_enum::lds>(
176  smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
177  Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
178  {0, 0, 0});
179 
180  auto a_sst_win1 =
181  make_tile_window(make_tensor_view<address_space_enum::lds>(
182  smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
183  Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
184  {0, 0, 0});
// LDS load window over buffer 0, distributed to match WarpGemm0's A-operand
// layout. NOTE(review): original lines 190-194 (interior sequences of the
// outer distribution encoding) and 198 (the make_tile_window_linear call
// head) are missing from this extraction.
185  // m*k
186  auto a_sld_win0 = [&]() {
187  using WG = WarpGemm0;
188  constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
189  sequence<>,
195  sequence<0, 0>>{};
196  constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
197  a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
199  make_tensor_view<address_space_enum::lds>(
200  smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
201  Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
202  {0, 0},
203  make_static_tile_distribution(a_block_dstr_encode));
204  }();
205 
// Mirror of a_sld_win0 over buffer 1 (same missing lines: 211-215, 219).
206  // m*k
207  auto a_sld_win1 = [&]() {
208  using WG = WarpGemm0;
209  constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
210  sequence<>,
216  sequence<0, 0>>{};
217  constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
218  a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
220  make_tensor_view<address_space_enum::lds>(
221  smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
222  Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
223  {0, 0},
224  make_static_tile_distribution(a_block_dstr_encode));
225  }();
226 
// Bridge windows: reuse the front of `smem` to store acc_0 (as YDataType) and
// reload it with the Y tile distribution required by gemm-1.
227  auto bridge_sst_win = [&]() {
228  return make_tile_window(
229  make_tensor_view<address_space_enum::lds>(
230  reinterpret_cast<YDataType*>(smem),
231  Policy::template MakeBridgeLdsStoreDesc<Problem>()),
232  Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
233  {0, 0});
234  }();
235 
// NOTE(review): original line 237 (the `return make_tile_window...` head of
// this lambda) is missing from this extraction.
236  auto bridge_sld_win = [&]() {
238  make_tensor_view<address_space_enum::lds>(
239  reinterpret_cast<YDataType*>(smem),
240  Policy::template MakeBridgeLdsLoadDesc<Problem>()),
241  Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
242  {0, 0},
243  Policy::template MakeYTileDistribution<Problem>());
244  }();
245 
// NOTE(review): original line 247 — the declaration of the `gs` (and
// presumably `ds`/`as`) statically_indexed_array double register buffers used
// below — is missing from this extraction.
246  // also OK with C array, 2 register buffer
248 
// Issue counts used to size the software-pipelined schedules below: number of
// memory accesses per operand window, and number of mfma issues per gemm.
249  constexpr auto issues_a = number<a_win.get_num_of_access()>{};
250  constexpr auto issues_g = number<g_win.get_num_of_access()>{};
251  // constexpr auto issues_d = number<d_win.get_num_of_access()>{};
252  // constexpr auto issues_o = number<o_win.get_num_of_access()>{};
253  constexpr auto issues_gemm0 =
254  number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
255  warp_gemm_0.get_num_of_access()>{};
256  constexpr auto issues_gemm1 =
257  number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1 *
258  warp_gemm_1.get_num_of_access()>{};
259  // constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};
260 
// K-block count for gemm-0 and N-block count for gemm-1 (both derive from
// hidden_size: gemm-0 reduces over it, gemm-1 produces it), rounded up.
261  const index_t num_blocks_k0 =
262  (hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
263  const index_t num_blocks_n1 =
264  (hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;
265 
266  using a_thread_type = decltype(load_tile(a_sld_win0));
268 
// Tiny issue helpers, each emitting one raw (asynchronous / unwaited) memory
// instruction per call so the sequencer can interleave them with mfma.
// NOTE(review): the bodies of move_a/move_g/move_d (original lines 275,
// 299, 310 — presumably move_tile_window calls) are missing here.
269  auto gld_a = [&]<typename PreNop = bool_constant<false>>(
270  auto& a_store_, auto i_access, PreNop = {})
271  {
272  async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
273  };
274  auto move_a = [&]() {
276  };
277  auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
278  load_tile_raw(a_, win_, i_access);
279  };
280 
// Global load of G. When IsGateOnly, the window's bottom view is swapped
// mid-stream (first half of issues reads g_view, second half u_view) — but
// note u_view aliases g_view in that branch, so the swap is effectively a
// no-op here; verify intent against the original file.
281  auto gld_g = [&]<typename PreNop = bool_constant<false>>(
282  auto& g_, auto i_access, PreNop = {})
283  {
284  if constexpr(IsGateOnly)
285  {
286  // TODO: hack!
287  if constexpr(i_access.value == 0)
288  {
289  g_win.bottom_tensor_view_ = g_view;
290  }
291  else if constexpr(i_access.value == issues_g / 2)
292  {
293  g_win.bottom_tensor_view_ = u_view;
294  }
295  }
296  load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
297  };
298  auto move_g = [&]() {
300  };
302 
303  auto gld_d = [&]<typename PreNop = bool_constant<false>>(
304  auto& d_, auto i_access, PreNop = {})
305  {
306  load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
307  };
308  auto move_d = [&]() {
309  // d move along gemm-n
311  };
312 
// Accumulate one O tile into global memory via update_tile_raw (TRUE enables
// the out-of-bounds conditional check).
313  auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>(
314  auto& o_, auto i_access, PreNop = {})
315  {
316  update_tile_raw(o_win, o_, i_access, TRUE, PreNop{});
317  };
318 
// Accumulators: one block tile for gemm-0, two (double-buffered) for gemm-1.
319  auto acc_0 = Policy::template MakeCBlockTile_Gemm0<Problem>();
320  auto acc_1s = generate_tuple(
321  [&](auto) { return Policy::template MakeCBlockTile_Gemm1<Problem>(); }, number<2>{});
322 
// Single sub-issue of gemm-0: decode the flat i_access into (m, n, k, sub)
// indices, slice the matching warp-sized A/B/C fragments out of the block
// tiles, run one warp_gemm_0 access, and write the C fragment back.
323  // clang-format off
324  auto gemm_0 = [&]<typename PostNop = bool_constant<false>>
325  (auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
326  using WarpGemm = remove_cvref_t<decltype(warp_gemm_0)>;
327 
328  constexpr auto repeat_sub = WarpGemm::get_num_of_access();
329  constexpr auto repeat_m = BlockShape::Repeat_M0;
330  // constexpr auto repeat_n = BlockShape::Repeat_N0;
331  constexpr auto repeat_k = BlockShape::Repeat_K0;
332  // loop order n->m->k
333  constexpr auto i_sub = i_access % repeat_sub;
334  constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
335  constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
336  constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
337 
338  using AWarpTensor = typename WarpGemm::AWarpTensor;
339  using BWarpTensor = typename WarpGemm::BWarpTensor;
340  using CWarpTensor = typename WarpGemm::CWarpTensor;
341  using AWarpDstr = typename WarpGemm::AWarpDstr;
342  using BWarpDstr = typename WarpGemm::BWarpDstr;
343  using CWarpDstr = typename WarpGemm::CWarpDstr;
344 
345  constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
346  constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
347  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
348 
349  constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
350  constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
351  constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
352 
353  AWarpTensor w_a;
354  w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
355  merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
356  merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
357 
358  BWarpTensor w_b;
359  w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
360  merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
361  merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
362 
363  CWarpTensor w_c;
364  w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
365  merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
366  merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
367 
368  warp_gemm_0(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
369 
370  t_c.set_y_sliced_thread_data(
371  merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
372  merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
373  w_c.get_thread_buffer());
374  };
375  // clang-format on
376 
// Single sub-issue of gemm-1; same slicing scheme as gemm_0 but driven by
// warp_gemm_1. NOTE(review): the repeat_m/repeat_k here read the *_M0/*_K0
// shape constants rather than *_M1/*_K1 — verify against the original file
// whether that is deliberate (shapes may coincide) or a defect.
377  // clang-format off
378  auto gemm_1 = [&]<typename PostNop = bool_constant<false>>
379  (auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
380  using WarpGemm = remove_cvref_t<decltype(warp_gemm_1)>;
381 
382  constexpr auto repeat_sub = WarpGemm::get_num_of_access();
383  constexpr auto repeat_m = BlockShape::Repeat_M0;
384  // constexpr auto repeat_n = BlockShape::Repeat_N0;
385  constexpr auto repeat_k = BlockShape::Repeat_K0;
386  // loop order n->m->k
387  constexpr auto i_sub = i_access % repeat_sub;
388  constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
389  constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
390  constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
391 
392  using AWarpTensor = typename WarpGemm::AWarpTensor;
393  using BWarpTensor = typename WarpGemm::BWarpTensor;
394  using CWarpTensor = typename WarpGemm::CWarpTensor;
395  using AWarpDstr = typename WarpGemm::AWarpDstr;
396  using BWarpDstr = typename WarpGemm::BWarpDstr;
397  using CWarpDstr = typename WarpGemm::CWarpDstr;
398 
399  constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
400  constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
401  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
402 
403  constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
404  constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
405  constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
406 
407  AWarpTensor w_a;
408  w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
409  merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
410  merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
411 
412  BWarpTensor w_b;
413  w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
414  merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
415  merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
416 
417  CWarpTensor w_c;
418  w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
419  merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
420  merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
421 
422  warp_gemm_1(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
423 
424  t_c.set_y_sliced_thread_data(
425  merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
426  merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
427  w_c.get_thread_buffer());
428  };
429  // clang-format on
430  _Pragma("clang diagnostic pop");
431 
432  // this gemm pipeline is designed with assumption that issues of buffer-load/ds_read can
433  // be hide under mfma. In other words, issues of mfma is >= memory this is true if we
434  // pre-shuffle B matrix, and A matrix is relatively small we prefer use multiple mfma
435  // paired with 1 buffer-load B matrix, to get max throughput of buffer_load. and by
436  // preshuffle, we always pack to dwordx4 load, and this will already extend to multiple
437  // mfma but that is already consumed inside warpgemm-impl. So indeed how many extra
438  // mfma(that can reuse the B matrix) only affected by M repeat.
// One steady-state double-buffered iteration of stage 0: run all gemm-0
// issues against buffer 0 while the sequencer (GetSequencer_0) interleaves
// LDS reads / global prefetches into buffer 1, synchronize, then the mirror
// image against buffer 1. MAKE_SC/NEXT_SCI are static issue counters.
439  auto pipeline_gemm0 = [&]() {
440  constexpr index_t total_loops = issues_gemm0;
441  constexpr auto sr = Policy::template GetSequencer_0<Problem>();
442  static_assert(sr.size() == total_loops);
443 
444  constexpr auto c_sld_a_0 = MAKE_SC();
445  constexpr auto c_gld_a_0 = MAKE_SC();
446  constexpr auto c_gld_b_0 = MAKE_SC();
447  // compute buffer 1
448  static_for<0, total_loops, 1>{}([&](auto i_issue) {
449  gemm_0(acc_0, as[I0], gs[I0], i_issue);
450  constexpr index_t slot = sr.at(i_issue);
451 
452  if constexpr(slot & SLD_A)
453  sld_a(as[I1], a_sld_win1, number<NEXT_SCI(c_sld_a_0, i_issue)>{});
454  if constexpr(slot & GLD_A)
455  gld_a(a_sst_win0, number<NEXT_SCI(c_gld_a_0, i_issue)>{});
456  if constexpr(slot & GLD_B)
457  gld_g(gs[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
458  });
459  move_g();
460  move_a();
461  block_sync_load_raw(issues_a + issues_g);
462  lds_load_fence();
463 
464  constexpr auto c_sld_a_1 = MAKE_SC();
465  constexpr auto c_gld_a_1 = MAKE_SC();
466  constexpr auto c_gld_b_1 = MAKE_SC();
467 
468  // compute buffer 1
469  static_for<0, total_loops, 1>{}([&](auto i_issue) {
470  gemm_0(acc_0, as[I1], gs[I1], i_issue);
471  constexpr index_t slot = sr.at(i_issue);
472 
473  if constexpr(slot & SLD_A)
474  sld_a(as[I0], a_sld_win0, number<NEXT_SCI(c_sld_a_1, i_issue)>{});
475  if constexpr(slot & GLD_A)
476  gld_a(a_sst_win1, number<NEXT_SCI(c_gld_a_1, i_issue)>{});
477  if constexpr(slot & GLD_B)
478  gld_g(gs[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
479  });
480  move_g();
481  move_a();
482  block_sync_load_raw(issues_a + issues_g);
483  lds_load_fence();
484  };
485 
// Drain of stage 0: finish buffer 0 (still prefetching B for the final
// buffer), then compute the last buffer with no further prefetch; the last
// mfma carries a post-nop for scheduling.
486  auto pipeline_gemm0_tail = [&]() {
487  constexpr index_t total_loops = issues_gemm0;
488  constexpr auto sr = Policy::template GetSequencer_0<Problem>();
489  static_assert(sr.size() == total_loops);
490 
491  constexpr auto c_gld_b_0 = MAKE_SC();
492 
493  // compute buffer 0
494  static_for<0, total_loops, 1>{}([&](auto i_issue) {
495  gemm_0(acc_0, as[I0], gs[I0], i_issue);
496  constexpr index_t slot = sr.at(i_issue);
497 
498  if constexpr(slot & GLD_B)
499  gld_g(gs[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
500  });
501 
502  block_sync_load_raw(issues_g);
503  sld_a(as[I1], a_sld_win1, NEG1);
504 
505  // compute buffer 1
506  static_for<0, total_loops, 1>{}([&](auto i_issue) {
507  constexpr auto last_nop = [&]() {
508  if constexpr(i_issue == (total_loops - 1))
509  return TRUE;
510  else
511  return FALSE;
512  }();
513  gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop); // last gemm has nop
514  });
515  };
516 
517  auto y = Policy::template MakeYBlockTile<Problem>();
518 
// Bridge between the two gemms: cast acc_0 to YDataType, store it to LDS,
// reload it as `y` with the gemm-1 distribution, clearing both stage-1
// accumulators in the gaps to hide latency.
519  auto pipeline_bridge = [&]() {
520  // cast to Y data
521  auto y_pre = cast_tile<YDataType>(acc_0);
522  store_tile(bridge_sst_win, y_pre);
523  clear_tile(acc_1s(I0));
524  // wave_barrier();
525  load_tile(y, bridge_sld_win);
526  clear_tile(acc_1s(I1));
527  };
528 
// Steady-state stage 1 iteration: compute into one acc_1s buffer while
// prefetching the next D tile and flushing the other, finished accumulator
// to the output window; then swap roles.
529  // note, gemm-1 start from idx-1 to N-2 (0, 1, 2....N-1)
530  auto pipeline_gemm1 = [&]() {
531  constexpr index_t total_loops = issues_gemm1;
532  constexpr auto sr = Policy::template GetSequencer_1<Problem>();
533  static_assert(sr.size() == total_loops);
534 
535  constexpr auto c_gld_b_0 = MAKE_SC();
536  constexpr auto c_gst_o_0 = MAKE_SC();
537  constexpr auto c_gld_b_1 = MAKE_SC();
538  constexpr auto c_gst_o_1 = MAKE_SC();
539 
540  // compute buffer 0
541  static_for<0, total_loops, 1>{}([&](auto i_issue) {
542  gemm_1(acc_1s[I1], y, ds[I1], i_issue);
543  constexpr index_t slot = sr.at(i_issue);
544  if constexpr(slot & GLD_B)
545  gld_d(ds[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
546 
547  if constexpr(slot & GST_O)
548  {
549  auto out = cast_tile<ODataType>(acc_1s[I0]);
550  atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
551  }
552  });
553  move_d();
554  // move_o();
555 
556  // compute buffer 1
557  static_for<0, total_loops, 1>{}([&](auto i_issue) {
558  gemm_1(acc_1s[I0], y, ds[I0], i_issue);
559  constexpr index_t slot = sr.at(i_issue);
560  if constexpr(slot & GLD_B)
561  gld_d(ds[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
562 
563  if constexpr(slot & GST_O)
564  {
565  auto out = cast_tile<ODataType>(acc_1s[I1]);
566  atomic_add_o(out, number<NEXT_SCI(c_gst_o_1, i_issue)>{});
567  }
568  });
569  move_d();
570  };
571 
// Prologue of stage 1: first compute pass with D prefetch only (nothing to
// flush to O yet).
572  auto pipeline_gemm1_head = [&]() {
573  constexpr index_t total_loops = issues_gemm1;
574  constexpr auto sr = Policy::template GetSequencer_1<Problem>();
575  static_assert(sr.size() == total_loops);
576 
577  constexpr auto c_gld_b_0 = MAKE_SC();
578 
579  // compute buffer 0
580  static_for<0, total_loops, 1>{}([&](auto i_issue) {
581  gemm_1(acc_1s[I0], y, ds[I0], i_issue);
582  constexpr index_t slot = sr.at(i_issue);
583  if constexpr(slot & GLD_B)
584  gld_d(ds[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
585  });
586  move_d();
587  };
// Epilogue of stage 1: last compute pass, flushing the previous accumulator
// during the loop and the final one afterwards.
588  auto pipeline_gemm1_tail = [&]() {
589  constexpr index_t total_loops = issues_gemm1;
590  constexpr auto sr = Policy::template GetSequencer_1<Problem>();
591  static_assert(sr.size() == total_loops);
592 
593  constexpr auto c_gst_o_0 = MAKE_SC();
594 
595  // compute buffer 1
596  static_for<0, total_loops, 1>{}([&](auto i_issue) {
597  gemm_1(acc_1s[I1], y, ds[I1], i_issue);
598 
599  constexpr index_t slot = sr.at(i_issue);
600  if constexpr(slot & GST_O)
601  {
602  auto out = cast_tile<ODataType>(acc_1s[I0]);
603  atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
604  }
605  });
606  {
607  auto out = cast_tile<ODataType>(acc_1s[I1]);
608  atomic_add_o(out, NEG1);
609  }
610  };
611 
// Driver: prime both A/G buffers, run (num_blocks_k0 - 2) / 2 steady-state
// stage-0 iterations (each covers two K blocks — hence the manual unroll),
// drain, bridge, then head + steady-state + tail for stage 1.
612  // start of pipeline
613  // clang-format off
614  gld_a(a_sst_win0, NEG1, TRUE);
615  gld_g(gs[I0], NEG1, TRUE);
616  move_a();
617  move_g();
618  clear_tile(acc_0);
619 
620  // preload for next round
621  gld_a(a_sst_win1, NEG1);
622  gld_g(gs[I1], NEG1);
623 
624  // make sure a,g loaded
625  block_sync_load_raw(issues_a + issues_g);
626  lds_load_fence();
627 
628  // we manually unroll double buffer inside hot loop
629  const index_t iters_0 = (num_blocks_k0 - 2) / 2;
630  index_t i_0 = 0; // (void)i_0; (void)iters_0; (void)pipeline_gemm0;
631  while(i_0++ < iters_0)
632  {
633  pipeline_gemm0();
634  }
635  pipeline_gemm0_tail();
636 
637  pipeline_bridge();
638 
639  const index_t iters_1 = (num_blocks_n1 - 2) / 2;
640  index_t i_1 = 0; // (void) i_1; (void)iters_1; (void)pipeline_gemm1;
641  pipeline_gemm1_head();
642  while(i_1++ < iters_1)
643  {
644  pipeline_gemm1();
645  }
646  pipeline_gemm1_tail();
647  // clang-format on
648  }
649 };
650 
651 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_LDS_ADDR
Definition: config.hpp:56
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:420
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: load_tile.hpp:149
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
int32_t index_t
Definition: integer.hpp:9
CK_TILE_DEVICE void lds_load_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:624
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:27
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1046
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:817
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
CK_TILE_DEVICE auto load_tile_raw(T &tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Loads a tile of data using inline assembly.
Definition: load_tile.hpp:106
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:1124
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:400
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:145
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt=0)
Definition: arch.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1017
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: update_tile.hpp:68
#define NEXT_SCI(c_, static_i_)
Definition: static_counter.hpp:109
#define MAKE_SC()
Definition: static_counter.hpp:104
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:23
static constexpr index_t SLD_A
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:54
static constexpr index_t kAlignmentA
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:49
static constexpr bool PadHiddenSize
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:46
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize_A()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:72
typename Problem::DScaleDataType DScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:36
typename Problem::BlockShape BlockShape
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:27
static constexpr index_t kAlignmentO
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:52
typename Problem::IndexDataType IndexDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:39
static constexpr index_t kBlockPerCu
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:59
remove_cvref_t< Policy_ > Policy
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:25
static constexpr const char * name
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:69
static constexpr index_t GLD_B
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:56
static constexpr index_t GLD_A
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:55
remove_cvref_t< Problem_ > Problem
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:24
typename Problem::ADataType ADataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:29
static CK_TILE_HOST_DEVICE auto GetOCoord()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:91
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:77
typename Problem::GDataType GDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:30
static constexpr index_t kAlignmentG
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:50
typename Problem::DDataType DDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:31
static CK_TILE_HOST_DEVICE auto GetACoord()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:83
typename Problem::ODataType ODataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:33
typename Problem::GScaleDataType GScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:35
typename Problem::AccDataType AccDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:32
static constexpr index_t GST_O
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:57
typename Problem::TopkWeightDataType TopkWeightDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:38
static constexpr index_t kAlignmentD
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:51
typename Problem::YDataType YDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:40
static constexpr bool IsGateOnly
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:44
typename Problem::Traits Traits
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:42
typename Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:37
static constexpr bool PadIntermediateSize
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:47
typename Problem::AScaleDataType AScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:34
static constexpr bool UseSmoothQuant
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:45
CK_TILE_DEVICE auto operator()(const AWindow &a_window_, const GWindow &g_window_, const DWindow &d_window_, OWindow &o_window_, TopkWeightDataType, CK_TILE_LDS_ADDR void *smem, index_t hidden_size, index_t intermediate_size)
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:99
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: sequence.hpp:52
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192