/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp Source File
fused_moegemm_pipeline_flatmm_ex.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 /*
13 This pipeline deals with a gemm (actually 2 gemms) with one very small operand (token) and one very big operand (weight).
14 We need to design the pipeline such that all waves are placed along the gemm-N dim (gemm-m has only 1 wave):
15 
16  <----- gemm-N ------>
17  +----+----+----+----+
18  | w0 | w1 | w2 | w3 | gemm-m
19  +----+----+----+----+
20 */
21 template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
23 {
26 
27  using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
28 
29  using ADataType = typename Problem::ADataType;
30  using GDataType = typename Problem::GDataType;
31  using DDataType = typename Problem::DDataType;
32  using AccDataType = typename Problem::AccDataType;
33  using ODataType = typename Problem::ODataType;
34  using AScaleDataType = typename Problem::AScaleDataType;
35  using GScaleDataType = typename Problem::GScaleDataType;
36  using DScaleDataType = typename Problem::DScaleDataType;
37  using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
38  using TopkWeightDataType = typename Problem::TopkWeightDataType;
39  using IndexDataType = typename Problem::IndexDataType;
40  using YDataType = typename Problem::YDataType;
41 
42  using Traits = typename Problem::Traits;
43 
44  static constexpr bool IsGateOnly = Traits::IsGateOnly;
45  static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
46  static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
47  static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
48 
49  static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
50  static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
51  static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
52  static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
53 
58 
59  static constexpr index_t kBlockPerCu = []() {
     // Evaluated at compile time: forward Problem::kBlockPerCu unless it is -1,
     // in which case the pipeline falls back to the hard-coded value below.
     // NOTE(review): -1 presumably means "not specified by the Problem" — confirm.
60  if constexpr(Problem::kBlockPerCu != -1)
61  return Problem::kBlockPerCu;
62  else
63  {
64  // minimize occupancy
65  return 2;
66  }
67  }();
68 
69  static constexpr const char* name = "fused_moe_flatmm";
70 
71  // TODO: there are multiple buffers
73  {
74  return Policy::template GetSmemSize_A<Problem>();
75  }
76 
78  {
79  return Policy::template GetSmemSize<Problem>();
80  }
81 
82  // this is the thread-offset along row/col
84  {
85  constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
86  const auto a_coord = a_dist.calculate_index();
87  return a_coord;
88  }
89 
90  // this is the thread-offset along row/col
92  {
93  constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
94  const auto o_coord = o_dist.calculate_index();
95  return o_coord;
96  }
97 
98  template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
99  CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
100  const GWindow& g_window_,
101  const DWindow& d_window_,
102  OWindow& o_window_,
103  TopkWeightDataType /*topk_weight*/,
104  CK_TILE_LDS_ADDR void* smem,
105  index_t hidden_size,
106  index_t intermediate_size)
107  {
108  _Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"");
109  constexpr auto NEG1 = number<-1>{};
110  constexpr auto I0 = number<0>{};
111  constexpr auto I1 = number<1>{};
112  constexpr auto TRUE = bool_constant<true>{};
113  constexpr auto FALSE = bool_constant<false>{};
114 
115  CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
116  CK_TILE_LDS_ADDR ADataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(
117  reinterpret_cast<CK_TILE_LDS_ADDR char*>(smem) +
118  Policy::template GetSmemSize_A<Problem>());
119 
120  auto g_view = g_window_.get_bottom_tensor_view();
121 
122  auto u_view = [&]() {
123  if constexpr(IsGateOnly)
124  {
125  return g_view;
126  }
127  else
128  {
129  index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
130  index_t kr_0 = hidden_size / BlockShape::Block_Kr0;
131 
132  const GDataType* g_ptr =
133  g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
134  const GDataType* u_ptr = g_ptr + (nr_0 / 2) * kr_0 * number<BlockShape::Block_W0>{};
135 
136  const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
137  u_ptr,
139  make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
141  number<1>{});
142  const auto u_view_1_ =
143  pad_tensor_view(u_view_,
148  return u_view_1_;
149  }
150  }();
151 
152  auto a_win = make_tile_window_linear(
153  a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
154  auto g_win =
155  make_tile_window_linear(g_window_,
156  Policy::template MakeGlobalTileDistribution_G<Problem>(),
158  auto d_win =
159  make_tile_window_linear(d_window_,
160  Policy::template MakeGlobalTileDistribution_D<Problem>(),
162  auto o_win = make_tile_window_linear(
163  o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
164 
165  using g_thread_type = decltype(load_tile(g_win));
166  using d_thread_type = decltype(load_tile(d_win));
167 
168  using WarpGemm0 = decltype(Policy::template GetWarpGemm0<Problem>());
169  using WarpGemm1 = decltype(Policy::template GetWarpGemm1<Problem>());
170  auto warp_gemm_0 = WarpGemm0{};
171  auto warp_gemm_1 = WarpGemm1{};
172 
173  // issues_warps_lanes
174  auto a_sst_win0 =
175  make_tile_window(make_tensor_view<address_space_enum::lds>(
176  smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
177  Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
178  {0, 0, 0});
179 
180  auto a_sst_win1 =
181  make_tile_window(make_tensor_view<address_space_enum::lds>(
182  smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
183  Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
184  {0, 0, 0});
185  // m*k
186  auto a_sld_win0 = [&]() {
187  using WG = WarpGemm0;
188  constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
189  sequence<>,
195  sequence<0, 0>>{};
196  constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
197  a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
199  make_tensor_view<address_space_enum::lds>(
200  smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
201  Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
202  {0, 0},
203  make_static_tile_distribution(a_block_dstr_encode));
204  }();
205 
206  // m*k
207  auto a_sld_win1 = [&]() {
208  using WG = WarpGemm0;
209  constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
210  sequence<>,
216  sequence<0, 0>>{};
217  constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
218  a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
220  make_tensor_view<address_space_enum::lds>(
221  smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
222  Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
223  {0, 0},
224  make_static_tile_distribution(a_block_dstr_encode));
225  }();
226 
227  auto bridge_sst_win = [&]() {
228  return make_tile_window(
229  make_tensor_view<address_space_enum::lds>(
230  reinterpret_cast<YDataType*>(smem),
231  Policy::template MakeBridgeLdsStoreDesc<Problem>()),
232  Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
233  {0, 0});
234  }();
235 
236  auto bridge_sld_win = [&]() {
238  make_tensor_view<address_space_enum::lds>(
239  reinterpret_cast<YDataType*>(smem),
240  Policy::template MakeBridgeLdsLoadDesc<Problem>()),
241  Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
242  {0, 0},
243  Policy::template MakeYTileDistribution<Problem>());
244  }();
245 
246  // also OK with C array, 2 register buffer
248 
249  constexpr auto issues_a = number<a_win.get_num_of_access()>{};
250  constexpr auto issues_g = number<g_win.get_num_of_access()>{};
251  // constexpr auto issues_d = number<d_win.get_num_of_access()>{};
252  // constexpr auto issues_o = number<o_win.get_num_of_access()>{};
253  constexpr auto issues_gemm0 =
254  number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
255  warp_gemm_0.get_num_of_access()>{};
256  constexpr auto issues_gemm1 =
257  number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1 *
258  warp_gemm_1.get_num_of_access()>{};
259  // constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};
260 
261  const index_t num_blocks_k0 =
262  (hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
263  const index_t num_blocks_n1 =
264  (hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;
265 
266  using a_thread_type = decltype(load_tile(a_sld_win0));
268 
269  auto gld_a = [&]<typename PreNop = bool_constant<false>>(
270  auto& a_store_, auto i_access, PreNop = {}) {
     // Issues access `i_access` of the A tile: loads from the linear window a_win
     // into the LDS store window `a_store_` through async_load_tile_raw.
     // NOTE(review): PreNop{} is passed as the 4th argument. The async_load_tile_raw
     // signature listed in this page's index names the 4th parameter
     // oob_conditional_check and the 5th pre_nop, and gld_g/gld_d pass FALSE before
     // PreNop{} — confirm this argument position is intended.
271  async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
272  };
273  auto move_a = [&]() {
275  };
276  auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
     // Reads access `i_access` of window `win_` into the thread tile `a_`
     // via load_tile_raw (callers pass the LDS load windows a_sld_win0/a_sld_win1).
277  load_tile_raw(a_, win_, i_access);
278  };
279 
280  auto gld_g =
281  [&]<typename PreNop = bool_constant<false>>(auto& g_, auto i_access, PreNop = {}) {
     // Issues access `i_access` of the gate/up weight tile: loads from g_win into
     // the thread tile `g_` with the 4th load_tile_raw argument set to FALSE.
     // Before the load, the bottom tensor view of g_win may be rebound: g_view at
     // access 0, u_view at access issues_g / 2. No rebinding happens for any other
     // access index (including NEG1, which the pipeline prologue uses).
     // NOTE(review): the rebinding is guarded by IsGateOnly, yet in that case the
     // u_view lambda returns g_view itself, so both assignments bind the same view;
     // the distinct u_view built when !IsGateOnly is never bound here. The
     // condition may be inverted — confirm against the intended gate+up layout.
282  if constexpr(IsGateOnly)
283  {
284  // TODO: hack!
285  if constexpr(i_access.value == 0)
286  {
287  g_win.bottom_tensor_view_ = g_view;
288  }
289  else if constexpr(i_access.value == issues_g / 2)
290  {
291  g_win.bottom_tensor_view_ = u_view;
292  }
293  }
294  load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
295  };
296  auto move_g = [&]() {
298  };
300 
301  auto gld_d =
302  [&]<typename PreNop = bool_constant<false>>(auto& d_, auto i_access, PreNop = {}) {
     // Issues access `i_access` of the down-projection weight tile: loads from d_win
     // into the thread tile `d_`, passing FALSE as the 4th load_tile_raw argument
     // (oob_conditional_check in the signature listed in this page's index) and
     // PreNop{} as the 5th.
303  load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
304  };
305  auto move_d = [&]() {
306  // d move along gemm-n
308  };
309 
310  auto atomic_add_o =
311  [&]<typename PreNop = bool_constant<false>>(auto& o_, auto i_access, PreNop = {}) {
     // Applies access `i_access` of tile `o_` to the output window o_win through
     // update_tile_raw, passing TRUE as the 4th argument (oob_conditional_check in
     // the signature listed in this page's index) and PreNop{} as the 5th.
     // NOTE(review): the lambda name suggests the update is an atomic add; that is
     // decided inside update_tile_raw / o_win and is not visible here — confirm.
312  update_tile_raw(o_win, o_, i_access, TRUE, PreNop{});
313  };
314 
315  auto acc_0 = Policy::template MakeCBlockTile_Gemm0<Problem>();
316  auto acc_1s = generate_tuple(
317  [&](auto) { return Policy::template MakeCBlockTile_Gemm1<Problem>(); }, number<2>{});
318 
319  // clang-format off
320  auto gemm_0 = [&]<typename PostNop = bool_constant<false>>
321  (auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
     // Executes ONE issue of the first gemm on block-level thread tiles.
     //   t_c      : accumulator tile, updated in place
     //   t_a, t_b : A and B operand tiles
     //   i_access : compile-time linear issue index, decoded below
     //   PostNop  : forwarded to warp_gemm_0 as its last argument
     // The issue index is decoded as (i_n, i_m, i_k, i_sub) with i_sub varying
     // fastest, then the matching warp-sized slices of A/B/C are copied out, the
     // warp gemm runs on them, and the C slice is written back into t_c.
322  using WarpGemm = remove_cvref_t<decltype(warp_gemm_0)>;
323 
324  constexpr auto repeat_sub = WarpGemm::get_num_of_access();
325  constexpr auto repeat_m = BlockShape::Repeat_M0;
326  // constexpr auto repeat_n = BlockShape::Repeat_N0;
327  constexpr auto repeat_k = BlockShape::Repeat_K0;
328  // loop order n->m->k
     // i.e. n is the outermost index and k the innermost (sub-access below k)
329  constexpr auto i_sub = i_access % repeat_sub;
330  constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
331  constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
332  constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
333 
334  using AWarpTensor = typename WarpGemm::AWarpTensor;
335  using BWarpTensor = typename WarpGemm::BWarpTensor;
336  using CWarpTensor = typename WarpGemm::CWarpTensor;
337  using AWarpDstr = typename WarpGemm::AWarpDstr;
338  using BWarpDstr = typename WarpGemm::BWarpDstr;
339  using CWarpDstr = typename WarpGemm::CWarpDstr;
340 
     // all-zero Y origins of the warp-level part of each slice
341  constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
342  constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
343  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
344 
     // full Y extents of one warp tensor, used as the slice lengths
345  constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
346  constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
347  constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
348 
     // A slice at repeat position (i_m, i_k)
349  AWarpTensor w_a;
350  w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
351  merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
352  merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
353 
     // B slice at repeat position (i_n, i_k)
354  BWarpTensor w_b;
355  w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
356  merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
357  merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
358 
     // C slice at repeat position (i_m, i_n)
359  CWarpTensor w_c;
360  w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
361  merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
362  merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
363 
364  warp_gemm_0(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
365 
     // write the updated C slice back to the same (i_m, i_n) position
366  t_c.set_y_sliced_thread_data(
367  merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
368  merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
369  w_c.get_thread_buffer());
370  };
371  // clang-format on
372 
373  // clang-format off
374  auto gemm_1 = [&]<typename PostNop = bool_constant<false>>
375  (auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
     // Executes ONE issue of the second gemm on block-level thread tiles.
     //   t_c      : accumulator tile, updated in place
     //   t_a, t_b : A and B operand tiles
     //   i_access : compile-time linear issue index, decoded below
     //   PostNop  : forwarded to warp_gemm_1 as its last argument
     // The issue index is decoded as (i_n, i_m, i_k, i_sub) with i_sub varying
     // fastest, then the matching warp-sized slices of A/B/C are copied out, the
     // warp gemm runs on them, and the C slice is written back into t_c.
376  using WarpGemm = remove_cvref_t<decltype(warp_gemm_1)>;
377 
378  constexpr auto repeat_sub = WarpGemm::get_num_of_access();
     // The repeat counts must be the gemm-1 ones: callers iterate i_access over
     // issues_gemm1 = Repeat_M1 * Repeat_N1 * Repeat_K1 * get_num_of_access(), so
     // decoding with the gemm-0 repeats would select wrong slices whenever
     // Repeat_M0 != Repeat_M1 or Repeat_K0 != Repeat_K1.
379  constexpr auto repeat_m = BlockShape::Repeat_M1;
380  // constexpr auto repeat_n = BlockShape::Repeat_N1;
381  constexpr auto repeat_k = BlockShape::Repeat_K1;
382  // loop order n->m->k
     // i.e. n is the outermost index and k the innermost (sub-access below k)
383  constexpr auto i_sub = i_access % repeat_sub;
384  constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
385  constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
386  constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
387 
388  using AWarpTensor = typename WarpGemm::AWarpTensor;
389  using BWarpTensor = typename WarpGemm::BWarpTensor;
390  using CWarpTensor = typename WarpGemm::CWarpTensor;
391  using AWarpDstr = typename WarpGemm::AWarpDstr;
392  using BWarpDstr = typename WarpGemm::BWarpDstr;
393  using CWarpDstr = typename WarpGemm::CWarpDstr;
394 
     // all-zero Y origins of the warp-level part of each slice
395  constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
396  constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
397  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
398 
     // full Y extents of one warp tensor, used as the slice lengths
399  constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
400  constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
401  constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
402 
     // A slice at repeat position (i_m, i_k)
403  AWarpTensor w_a;
404  w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
405  merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
406  merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
407 
     // B slice at repeat position (i_n, i_k)
408  BWarpTensor w_b;
409  w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
410  merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
411  merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
412 
     // C slice at repeat position (i_m, i_n)
413  CWarpTensor w_c;
414  w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
415  merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
416  merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
417 
418  warp_gemm_1(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
419 
     // write the updated C slice back to the same (i_m, i_n) position
420  t_c.set_y_sliced_thread_data(
421  merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
422  merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
423  w_c.get_thread_buffer());
424  };
425  // clang-format on
426  _Pragma("clang diagnostic pop");
427 
428  // This gemm pipeline is designed with the assumption that the issues of buffer-load/ds_read
429  // can be hidden under mfma; in other words, the number of mfma issues is >= the number of
430  // memory issues. This holds if we pre-shuffle the B matrix and the A matrix is relatively
431  // small. We prefer to pair multiple mfma with 1 buffer-load of the B matrix, to get the max
432  // throughput of buffer_load. With preshuffle we always pack to dwordx4 loads, which already
433  // extends to multiple mfma, but that is already consumed inside warpgemm-impl. So how many
434  // extra mfma (that can reuse the B matrix) we get is only affected by the M repeat.
435  auto pipeline_gemm0 = [&]() {
     // One hot-loop step of the first gemm, manually unrolled over both buffers.
     // Each half runs every gemm_0 issue on one buffer and, as dictated per issue
     // by the sequencer slot bits (SLD_A / GLD_A / GLD_B), interleaves an LDS read
     // of A, a global load of A and a global load of the gate weights. After each
     // half the A and G windows are advanced and outstanding loads are waited on.
436  constexpr index_t total_loops = issues_gemm0;
437  constexpr auto sr = Policy::template GetSequencer_0<Problem>();
     // the sequencer must provide exactly one slot per gemm issue
438  static_assert(sr.size() == total_loops);
439 
     // compile-time counters handing out consecutive access indices per load kind
440  constexpr auto c_sld_a_0 = MAKE_SC();
441  constexpr auto c_gld_a_0 = MAKE_SC();
442  constexpr auto c_gld_b_0 = MAKE_SC();
443  // compute buffer 0 (as[I0], gs[I0])
444  static_for<0, total_loops, 1>{}([&](auto i_issue) {
445  gemm_0(acc_0, as[I0], gs[I0], i_issue);
446  constexpr index_t slot = sr.at(i_issue);
447 
448  if constexpr(slot & SLD_A)
449  sld_a(as[I1], a_sld_win1, number<NEXT_SCI(c_sld_a_0, i_issue)>{});
450  if constexpr(slot & GLD_A)
451  gld_a(a_sst_win0, number<NEXT_SCI(c_gld_a_0, i_issue)>{});
452  if constexpr(slot & GLD_B)
453  gld_g(gs[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
454  });
455  move_g();
456  move_a();
457  block_sync_load_raw(issues_a + issues_g);
458  lds_load_fence();
459 
460  constexpr auto c_sld_a_1 = MAKE_SC();
461  constexpr auto c_gld_a_1 = MAKE_SC();
462  constexpr auto c_gld_b_1 = MAKE_SC();
463 
464  // compute buffer 1 (as[I1], gs[I1])
465  static_for<0, total_loops, 1>{}([&](auto i_issue) {
466  gemm_0(acc_0, as[I1], gs[I1], i_issue);
467  constexpr index_t slot = sr.at(i_issue);
468 
469  if constexpr(slot & SLD_A)
470  sld_a(as[I0], a_sld_win0, number<NEXT_SCI(c_sld_a_1, i_issue)>{});
471  if constexpr(slot & GLD_A)
472  gld_a(a_sst_win1, number<NEXT_SCI(c_gld_a_1, i_issue)>{});
473  if constexpr(slot & GLD_B)
474  gld_g(gs[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
475  });
476  move_g();
477  move_a();
478  block_sync_load_raw(issues_a + issues_g);
479  lds_load_fence();
480  };
481 
482  auto pipeline_gemm0_tail = [&]() {
     // Final step of the first gemm: runs both buffers once more without issuing
     // any further A loads from global memory and without moving the windows.
483  constexpr index_t total_loops = issues_gemm0;
484  constexpr auto sr = Policy::template GetSequencer_0<Problem>();
     // the sequencer must provide exactly one slot per gemm issue
485  static_assert(sr.size() == total_loops);
486 
487  constexpr auto c_gld_b_0 = MAKE_SC();
488 
489  // compute buffer 0
     // only the GLD_B slots are honoured here; they load into gs[I1]
490  static_for<0, total_loops, 1>{}([&](auto i_issue) {
491  gemm_0(acc_0, as[I0], gs[I0], i_issue);
492  constexpr index_t slot = sr.at(i_issue);
493 
494  if constexpr(slot & GLD_B)
495  gld_g(gs[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
496  });
497 
     // wait for the weight loads, then read the whole A tile (NEG1) from LDS
498  block_sync_load_raw(issues_g);
499  sld_a(as[I1], a_sld_win1, NEG1);
500 
501  // compute buffer 1
502  static_for<0, total_loops, 1>{}([&](auto i_issue) {
     // TRUE only for the very last issue, FALSE otherwise
503  constexpr auto last_nop = [&]() {
504  if constexpr(i_issue == (total_loops - 1))
505  return TRUE;
506  else
507  return FALSE;
508  }();
509  gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop); // last gemm has nop
510  });
511  };
512 
513  auto y = Policy::template MakeYBlockTile<Problem>();
514 
515  auto pipeline_bridge = [&]() {
     // Hands the result of the first gemm over to the second one: acc_0 is
     // converted to YDataType, written through bridge_sst_win and read back
     // through bridge_sld_win into y. Both gemm-1 accumulators are cleared in
     // between these steps.
516  // cast to Y data
517  auto y_pre = cast_tile<YDataType>(acc_0);
518  store_tile(bridge_sst_win, y_pre);
519  clear_tile(acc_1s(I0));
     // NOTE(review): the barrier below is commented out, so no explicit
     // synchronization is visible between the store above and the load below —
     // confirm that none is required for the bridge layouts in use.
520  // wave_barrier();
521  load_tile(y, bridge_sld_win);
522  clear_tile(acc_1s(I1));
523  };
524 
525  // note: out of the blocks (0, 1, 2....N-1), this gemm-1 hot loop handles idx-1 to N-2; the first and last blocks are handled by the head/tail variants
526  auto pipeline_gemm1 = [&]() {
     // One hot-loop step of the second gemm, manually unrolled over both buffers.
     // While one accumulator/weight buffer is being computed, the sequencer slot
     // bits decide per issue whether to load the other weight buffer (GLD_B) and
     // whether to write out the other accumulator (GST_O).
     // NOTE(review): no clear_tile on acc_1s is visible in this lambda (they are
     // cleared in pipeline_bridge only) and o_win is not moved (move_o is commented
     // out) — confirm this is intended.
527  constexpr index_t total_loops = issues_gemm1;
528  constexpr auto sr = Policy::template GetSequencer_1<Problem>();
     // the sequencer must provide exactly one slot per gemm issue
529  static_assert(sr.size() == total_loops);
530 
     // compile-time counters handing out consecutive access indices per operation
531  constexpr auto c_gld_b_0 = MAKE_SC();
532  constexpr auto c_gst_o_0 = MAKE_SC();
533  constexpr auto c_gld_b_1 = MAKE_SC();
534  constexpr auto c_gst_o_1 = MAKE_SC();
535 
536  // compute buffer 1 (acc_1s[I1], ds[I1]); load ds[I0] and write out acc_1s[I0]
537  static_for<0, total_loops, 1>{}([&](auto i_issue) {
538  gemm_1(acc_1s[I1], y, ds[I1], i_issue);
539  constexpr index_t slot = sr.at(i_issue);
540  if constexpr(slot & GLD_B)
541  gld_d(ds[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
542 
543  if constexpr(slot & GST_O)
544  {
545  auto out = cast_tile<ODataType>(acc_1s[I0]);
546  atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
547  }
548  });
549  move_d();
550  // move_o();
551 
552  // compute buffer 0 (acc_1s[I0], ds[I0]); load ds[I1] and write out acc_1s[I1]
553  static_for<0, total_loops, 1>{}([&](auto i_issue) {
554  gemm_1(acc_1s[I0], y, ds[I0], i_issue);
555  constexpr index_t slot = sr.at(i_issue);
556  if constexpr(slot & GLD_B)
557  gld_d(ds[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
558 
559  if constexpr(slot & GST_O)
560  {
561  auto out = cast_tile<ODataType>(acc_1s[I1]);
562  atomic_add_o(out, number<NEXT_SCI(c_gst_o_1, i_issue)>{});
563  }
564  });
565  move_d();
566  };
567 
568  auto pipeline_gemm1_head = [&]() {
     // First step of the second gemm: computes buffer 0 and, on GLD_B slots,
     // loads the weights for buffer 1. Nothing is written to the output yet.
569  constexpr index_t total_loops = issues_gemm1;
570  constexpr auto sr = Policy::template GetSequencer_1<Problem>();
     // the sequencer must provide exactly one slot per gemm issue
571  static_assert(sr.size() == total_loops);
572 
573  constexpr auto c_gld_b_0 = MAKE_SC();
574 
575  // compute buffer 0
576  static_for<0, total_loops, 1>{}([&](auto i_issue) {
577  gemm_1(acc_1s[I0], y, ds[I0], i_issue);
578  constexpr index_t slot = sr.at(i_issue);
579  if constexpr(slot & GLD_B)
580  gld_d(ds[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
581  });
582  move_d();
583  };
584  auto pipeline_gemm1_tail = [&]() {
     // Last step of the second gemm: computes buffer 1 while writing out buffer 0
     // on GST_O slots, then writes out buffer 1 in one go. No further weight
     // loads are issued.
585  constexpr index_t total_loops = issues_gemm1;
586  constexpr auto sr = Policy::template GetSequencer_1<Problem>();
     // the sequencer must provide exactly one slot per gemm issue
587  static_assert(sr.size() == total_loops);
588 
589  constexpr auto c_gst_o_0 = MAKE_SC();
590 
591  // compute buffer 1
592  static_for<0, total_loops, 1>{}([&](auto i_issue) {
593  gemm_1(acc_1s[I1], y, ds[I1], i_issue);
594 
595  constexpr index_t slot = sr.at(i_issue);
596  if constexpr(slot & GST_O)
597  {
598  auto out = cast_tile<ODataType>(acc_1s[I0]);
599  atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
600  }
601  });
     // flush the accumulator that was just computed; NEG1 requests all accesses at once
602  {
603  auto out = cast_tile<ODataType>(acc_1s[I1]);
604  atomic_add_o(out, NEG1);
605  }
606  };
607 
608  // start of pipeline
609  // clang-format off
     // prologue: load the first A tile (into LDS buffer 0) and the first weight tile
     // (into gs[I0]) with all accesses at once (NEG1), then advance both windows
610  gld_a(a_sst_win0, NEG1, TRUE);
611  gld_g(gs[I0], NEG1, TRUE);
612  move_a();
613  move_g();
614  clear_tile(acc_0);
615 
616  // preload for next round
617  gld_a(a_sst_win1, NEG1);
618  gld_g(gs[I1], NEG1);
619 
620  // make sure a,g loaded
621  block_sync_load_raw(issues_a + issues_g);
622  lds_load_fence();
623 
624  // we manually unroll double buffer inside hot loop
     // Every pipeline_gemm0() call and the tail each run two gemm passes, so
     // 2 * iters_0 + 2 passes are executed in total.
     // NOTE(review): this matches num_blocks_k0 only when it is even and >= 2; for
     // an odd count the integer division drops one pass — confirm callers guarantee
     // this. The same applies to num_blocks_n1 below (head + tail = two passes).
625  const index_t iters_0 = (num_blocks_k0 - 2) / 2;
626  index_t i_0 = 0; // (void)i_0; (void)iters_0; (void)pipeline_gemm0;
627  while(i_0++ < iters_0)
628  {
629  pipeline_gemm0();
630  }
631  pipeline_gemm0_tail();
632 
     // convert the gemm-0 result into the A operand (y) of gemm-1
633  pipeline_bridge();
634 
635  const index_t iters_1 = (num_blocks_n1 - 2) / 2;
636  index_t i_1 = 0; // (void) i_1; (void)iters_1; (void)pipeline_gemm1;
637  pipeline_gemm1_head();
638  while(i_1++ < iters_1)
639  {
640  pipeline_gemm1();
641  }
642  pipeline_gemm1_tail();
643  // clang-format on
644  }
645 };
646 
647 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
int32_t index_t
Definition: integer.hpp:9
CK_TILE_DEVICE void lds_load_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:757
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1052
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:823
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE auto load_tile_raw(T &tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Loads a tile of data using inline assembly.
Definition: load_tile.hpp:58
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:993
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: load_tile.hpp:110
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:177
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1023
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: update_tile.hpp:68
#define NEXT_SCI(c_, static_i_)
Definition: static_counter.hpp:109
#define MAKE_SC()
Definition: static_counter.hpp:104
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:23
static constexpr index_t SLD_A
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:54
static constexpr index_t kAlignmentA
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:49
static constexpr bool PadHiddenSize
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:46
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize_A()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:72
typename Problem::DScaleDataType DScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:36
typename Problem::BlockShape BlockShape
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:27
static constexpr index_t kAlignmentO
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:52
typename Problem::IndexDataType IndexDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:39
static constexpr index_t kBlockPerCu
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:59
remove_cvref_t< Policy_ > Policy
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:25
static constexpr const char * name
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:69
static constexpr index_t GLD_B
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:56
static constexpr index_t GLD_A
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:55
remove_cvref_t< Problem_ > Problem
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:24
typename Problem::ADataType ADataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:29
static CK_TILE_HOST_DEVICE auto GetOCoord()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:91
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:77
typename Problem::GDataType GDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:30
static constexpr index_t kAlignmentG
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:50
typename Problem::DDataType DDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:31
static CK_TILE_HOST_DEVICE auto GetACoord()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:83
typename Problem::ODataType ODataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:33
typename Problem::GScaleDataType GScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:35
typename Problem::AccDataType AccDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:32
static constexpr index_t GST_O
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:57
typename Problem::TopkWeightDataType TopkWeightDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:38
static constexpr index_t kAlignmentD
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:51
typename Problem::YDataType YDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:40
static constexpr bool IsGateOnly
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:44
typename Problem::Traits Traits
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:42
typename Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:37
static constexpr bool PadIntermediateSize
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:47
typename Problem::AScaleDataType AScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:34
static constexpr bool UseSmoothQuant
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:45
CK_TILE_DEVICE auto operator()(const AWindow &a_window_, const GWindow &g_window_, const DWindow &d_window_, OWindow &o_window_, TopkWeightDataType, CK_TILE_LDS_ADDR void *smem, index_t hidden_size, index_t intermediate_size)
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:99
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192