/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp Source File
fused_moegemm_pipeline_flatmm_policy.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include "ck_tile/ops/flatmm.hpp"
11 
12 namespace ck_tile {
13 
15 {
// GetAsyncCopyDwords(): number of dwords moved by one async copy; currently fixed at 1.
// NOTE(review): the signature (listing line 16) is missing from this extract; the page
// index gives it as `static constexpr CK_TILE_HOST_DEVICE index_t GetAsyncCopyDwords()`.
17  {
18  // TODO: always 1 dword
19  return 1;
20  }
21 
22  template <typename Problem>
23  CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
24  {
25  // using async
26  constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
27  constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
28  static_assert(copy_bytes % data_bytes == 0);
29  return copy_bytes / data_bytes;
30  }
31 
32  template <typename Problem>
33  CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
34  {
35  constexpr index_t copy_bytes = [&]() { return 16; }();
36  constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
37  static_assert(copy_bytes % data_bytes == 0);
38  return copy_bytes / data_bytes;
39  }
40 
41  template <typename Problem>
42  CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
43  {
44  constexpr index_t copy_bytes = [&]() { return 16; }();
45  constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
46  static_assert(copy_bytes % data_bytes == 0);
47  return copy_bytes / data_bytes;
48  }
49 
50  template <typename Problem>
51  CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O()
52  {
53  if constexpr(Problem::Traits::OAtomic == 1)
54  {
55  // pack fp16/bf16 atomic
56  static_assert(sizeof(typename Problem::ODataType) == 2);
57  return 2;
58  }
59  else if constexpr(Problem::Traits::OAtomic == 2)
60  {
61  // fp32 atomic
62  return 1;
63  }
64  else
65  {
66  return 16 / sizeof(typename Problem::ODataType);
67  }
68  }
69 
70  template <typename DataType_>
71  CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
72  {
73  // TODO: this is for 3d layout
74  return 16 / sizeof(remove_cvref_t<DataType_>);
75  }
76 
77  template <typename Problem>
78  CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
79  {
80  return GetSmemKPack<typename Problem::ADataType>();
81  }
82 
83  // used for bridge LDS shuffle
84  template <typename Problem>
85  CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y()
86  {
87  // TODO: this should match mfma layout
88  return 16 / sizeof(typename Problem::YDataType);
89  }
90 
// GetSmemSize_A<Problem>(): element-space size of the A tile in LDS.
// Builds both the load and the store descriptor, asserts at compile time that they
// span the same element space, and returns that size.
// NOTE(review): the signature (listing line 92) is missing from this extract; the page
// index lists it as returning ck_tile::index_t.
// NOTE(review): the value is an element-space size, not a byte count — confirm how
// callers scale it.
91  template <typename Problem>
93  {
94  constexpr auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
95  constexpr auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
96  static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
97  return a_sld_desc.get_element_space_size();
98  }
99 
// GetSmemSize_Bridge<Problem>(): element-space size of the bridge buffer.
// The bridge load and store descriptors must span the same element space (checked by
// static_assert); that common size is returned.
// NOTE(review): the signature (listing line 101) is missing from this extract; the page
// index lists it as returning ck_tile::index_t.
100  template <typename Problem>
102  {
103  constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>();
104  constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>();
105  static_assert(bridge_sld_desc.get_element_space_size() ==
106  bridge_sst_desc.get_element_space_size());
107  return bridge_sld_desc.get_element_space_size();
108  }
109 
// GetSmemSize<Problem>(): the larger of the A-tile size and the bridge size.
// NOTE(review): taking the max rather than the sum presumably means the two buffers
// share the same storage and are never live together — confirm against the pipeline.
// NOTE(review): the signature (listing line 111) is missing from this extract; the page
// index lists it as returning ck_tile::index_t.
110  template <typename Problem>
112  {
113  constexpr index_t a_lds = GetSmemSize_A<Problem>();
114  constexpr index_t bridge_lds = GetSmemSize_Bridge<Problem>();
115  return max(a_lds, bridge_lds);
116  }
117 
// MakeGlobalTileDistribution_SimpleMxK<MPerBlock, KPerBlock, NumWarps, Alignment>():
// splits an MPerBlock x KPerBlock tile over NumWarps warps, Alignment elements per
// access along K.
//   K_vec = elements per access, K_rem = accesses needed to cover K.
//   Branch 1 (warp size <  K_rem): one warp cannot cover K, so K_wav warps share a row.
//   Branch 2 (warp size >= K_rem): a warp covers K with K_lan lanes; the remaining
//                                  lanes (M_lan) and all warps go along M.
// NOTE(review): the signature (listing line 119) and the lines that build and return
// the distribution (134-140, 151-157) are missing from this extract.
118  template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
120  {
121  constexpr index_t K_vec = Alignment;
122  constexpr index_t K_rem = KPerBlock / K_vec;
123 
124  if constexpr(get_warp_size() < K_rem)
125  {
126  static_assert(K_rem % get_warp_size() == 0);
127  constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
128  constexpr index_t K_wav = K_rem / get_warp_size();
129  static_assert(K_wav <= NumWarps, "not not support thread has repeat along K yet");
130  constexpr index_t M_wav = NumWarps / K_wav;
131  static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
132  constexpr index_t M_rep = MPerBlock / M_wav;
133 
136  sequence<1>,
141  sequence<0, 2>>{});
142  }
143  else
144  {
145  constexpr index_t K_lan = K_rem;
146  constexpr index_t M_lan = get_warp_size() / K_lan;
147  constexpr index_t M_wav = NumWarps;
148  static_assert(MPerBlock % (M_lan * M_wav) == 0,
149  "this tile size is too small please check");
150  constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
153  sequence<1>,
158  sequence<0, 1>>{});
159  }
160  }
161 
162  // optimized version for async, not same as simple MXK dist(pay attention!!)
// MakeGlobalTileDistribution_SimpleMxK_Async<MPerBlock, KPerBlock, NumWarps, Alignment>():
// same K_vec / K_rem split as the simple MxK distribution, except that the first branch
// is taken when warp size <= K_rem (the simple version uses a strict <).
// NOTE(review): the signature (listing line 164) and the lines that build and return
// the distribution are missing from this extract.
163  template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
165  {
166  constexpr index_t K_vec = Alignment;
167  constexpr index_t K_rem = KPerBlock / K_vec;
168 
169  if constexpr(get_warp_size() <= K_rem)
170  {
171  static_assert(K_rem % get_warp_size() == 0);
172  constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
173  constexpr index_t K_wav = K_rem / get_warp_size();
174  static_assert(K_wav <= NumWarps, "do not support thread has repeat along K yet");
175  constexpr index_t M_wav = NumWarps / K_wav;
176  static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
177  constexpr index_t M_rep = MPerBlock / M_wav;
178  // NOTE: no swap, but hard to avoid LDS bank conflict
181  sequence<1>,
186  sequence<0, 2>>{});
187  }
188  else
189  {
190  constexpr index_t K_lan = K_rem;
191  constexpr index_t M_lan = get_warp_size() / K_lan;
192  constexpr index_t M_wav = NumWarps;
193  static_assert(MPerBlock % (M_lan * M_wav) == 0,
194  "this tile size is too small please check");
195  constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
196  // NOTE: swapped for LDS load bank conflict free
199  sequence<1>,
200  // Note M_wave(num waves) is the fastest dim, different from simple 2d
201  // distribution
206  sequence<0, 1>>{});
207  }
208  }
209 
// MakeGlobalTileDistribution_Nr_Kr_W<...>(): tile distribution parameterised by warps
// per block along N and K, repeats along N and K, warp size and access alignment.
// NOTE(review): the signature (listing line 216) and nearly all of the body (218-225)
// are missing from this extract, so the encoding cannot be described from here.
210  template <index_t WarpPerBlock_N_,
211  index_t WarpPerBlock_K_,
212  index_t Repeat_N_,
213  index_t Repeat_K_,
214  index_t WarpSize_,
215  index_t Alignment_>
217  {
226  sequence<0, 0, 1>>{});
227  }
228 
// MakeGlobalTileDistribution_A<Problem>(): gathers the gemm-0 block shape (Block_M0,
// Block_K0), the warp count and the A alignment, and passes them as template arguments
// to another distribution builder.
// NOTE(review): the signature (listing line 230) and the line naming the callee (236)
// are missing from this extract; presumably the callee is one of the SimpleMxK
// builders — confirm against the original header.
229  template <typename Problem>
231  {
232  constexpr index_t Block_M_ = Problem::BlockShape::Block_M0;
233  constexpr index_t Block_K_ = Problem::BlockShape::Block_K0;
234  constexpr index_t NumWarps_ = Problem::BlockShape::NumWarps;
235  constexpr index_t Alignment_ = GetAlignment_A<Problem>();
237  Block_K_,
238  NumWarps_,
239  Alignment_>();
240  }
241 
// MakeGlobalTileDistribution_G<Problem>(): distribution for the G operand, built from
// the gemm-0 shape values (WarpPerBlock_N0/K0, Repeat_N0/K0) and GetAlignment_G.
// Only the b_nr_kr_waveflatten permutation is handled; for any other PermuteEnum no
// return statement is instantiated, so the deduced return type is void.
// NOTE(review): the signature (listing line 243) is missing from this extract.
242  template <typename Problem>
244  {
245  constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
246  // constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
247  using S_ = typename Problem::BlockShape;
248  if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
249  {
250  // number<S_::WarpPerBlock_N0>{}.rrr();
251  // number<S_::Repeat_N0>{}.eee();
252  return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0,
253  S_::WarpPerBlock_K0,
254  S_::Repeat_N0,
255  S_::Repeat_K0,
256  get_warp_size(),
257  GetAlignment_G<Problem>()>();
258  }
259  }
260 
// MakeGlobalTileDistribution_D<Problem>(): distribution for the D operand, built from
// the gemm-1 shape values (WarpPerBlock_N1/K1, Repeat_N1/K1) and GetAlignment_D.
// Only the b_nr_kr_waveflatten permutation is handled; for any other PermuteEnum no
// return statement is instantiated, so the deduced return type is void.
// NOTE(review): the signature (listing line 262) is missing from this extract.
261  template <typename Problem>
263  {
264  constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
265  using S_ = typename Problem::BlockShape;
266  if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
267  {
268  return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N1,
269  S_::WarpPerBlock_K1,
270  S_::Repeat_N1,
271  S_::Repeat_K1,
272  get_warp_size(),
273  GetAlignment_D<Problem>()>();
274  }
275  }
276 
// MakeGlobalTileDistribution_O<Problem>(): embeds a warp-level C distribution encoding
// (WarpGemm::CWarpDstrEncoding) into a block-level outer encoding and returns the
// resulting static tile distribution.
// NOTE(review): the signature (listing line 278), the definition of WarpGemm (280-281)
// and most of the outer encoding (285-290) are missing from this extract.
277  template <typename Problem>
279  {
282  // using CDataType = typename WarpGemm::CDataType;
283 
284  constexpr auto c_block_outer_dstr_encoding =
291  sequence<0, 0>>{};
292 
293  constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
294  c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
295  constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
296  return c_block_dstr;
297  }
298 
// MakeLdsStoreDesc_A<Problem>(): LDS descriptor for the A tile as written by the async
// copy.
//   KVector   = elements per async copy (GetAlignment_A)
//   KPack/KPad = elements per 16 bytes of ADataType, used as padding in the m0 stride
//   LanesPerK = number of copies needed to cover Block_K0
// Case LanesPerK >= WarpSize: several warps cover one K row (wavesPerK); the remaining
//   warps go along M. If wavesPerK > NumWarps nothing is returned (unimplemented TODO).
// Case LanesPerK <  WarpSize: one warp covers LaneGroups rows of M at once.
// The guaranteed last-dimension vector length of the naive descriptor is KVector.
// NOTE(review): several make_tuple / transform lines (listing lines 329, 336, 345-349,
// 362, 369, 377-382) are missing from this extract.
299  template <typename Problem>
300  CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
301  {
302  // A async->LDS
303  constexpr index_t Block_M = Problem::BlockShape::Block_M0;
304  constexpr index_t Block_K = Problem::BlockShape::Block_K0;
305  // constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
306  constexpr index_t WarpSize = ck_tile::get_warp_size();
307  constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
308 
309  constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
310  constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
311  constexpr index_t KPad = KPack; // pad between warps
312 
313  static_assert(Block_K % KVector == 0);
314  constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
315  if constexpr(LanesPerK >= WarpSize)
316  {
317  // need multiple waves to load K
318  static_assert(LanesPerK % WarpSize == 0);
319  constexpr index_t wavesPerK = LanesPerK / WarpSize;
320  if constexpr(wavesPerK > NumWarps)
321  {
322  // TODO: need multiple issues along K to load all data
323  }
324  else
325  {
326  constexpr index_t wavesPerM = NumWarps / wavesPerK;
327  constexpr index_t NumIssues = Block_M / wavesPerM;
328  constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
330  number<wavesPerM>{}, // m1
331  number<wavesPerK>{}, // k0
332  number<WarpSize>{}, // k1
333  number<KVector>{}), // k2
334  make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
335  number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
337  number<KVector>{}, // k1
338  number<1>{}), // k2
339  number<KVector>{}, // lds store vector(actually no explicit store)
340  number<1>{});
341 
342  constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
343  lds_block_desc_0,
344  make_tuple(
350 
351  return lds_block_desc_issues_warps_lanes;
352  }
353  }
354  else
355  {
356  // lanes within a wave load different M but same K
357  static_assert(WarpSize % LanesPerK == 0);
358  constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
359  constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
360 
361  constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
363  number<LaneGroups>{}, // m1
364  number<NumWarps>{}, // m2
365  number<LanesPerK>{}, // k0
366  number<KVector>{}), // k1
367  make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
368  number<Block_K>{}, // m1
370  number<KVector>{}, // k0
371  number<1>{}), // k1
372  number<KVector>{}, // lds store vector(actually no explicit store)
373  number<1>{});
374 
375  constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
376  lds_block_desc_0,
383 
384  return lds_block_desc_issues_warps_lanes;
385  }
386  }
387 
388  template <typename Problem>
389  CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
390  {
391  // A async->LDS
392  // Note that, this descriptor is only to construct the layout inside LDS
393  // in real Gemm pipeline, ds_read may not follow this pattern
394  // (may follow that in tile_distribution)
395  // below code is almost the same as SmemStore dist, with difference:
396  // 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
397  // 2). return discriptor is in NxK 2d layout
398  constexpr index_t Block_M = Problem::BlockShape::Block_M0;
399  constexpr index_t Block_K = Problem::BlockShape::Block_K0;
400  // constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
401  constexpr index_t WarpSize = ck_tile::get_warp_size();
402  constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
403 
404  constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
405  constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
406  constexpr index_t KPad = KPack; // pad between warps
407 
408  static_assert(Block_K % KVector == 0);
409  constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
410  if constexpr(LanesPerK >= WarpSize)
411  {
412  // need multiple waves to load K
413  static_assert(LanesPerK % WarpSize == 0);
414  constexpr index_t wavesPerK = LanesPerK / WarpSize;
415  if constexpr(wavesPerK >= NumWarps)
416  {
417  // TODO: need multiple issues along K to load all data
418  }
419  else
420  {
421  constexpr index_t wavesPerM = NumWarps / wavesPerK;
422  constexpr index_t NumIssues = Block_M / wavesPerM;
423  constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
425  number<wavesPerM>{}, // m1
426  number<wavesPerK>{}, // k0
427  number<WarpSize>{}, // k1
428  number<KVector>{}), // k2
429  make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
430  number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
432  number<KVector>{}, // k1
433  number<1>{}), // k2
434  number<KPack>{}, // lds load vector
435  number<1>{});
436 
437  constexpr auto lds_desc_m_k = transform_tensor_descriptor(
438  lds_block_desc_0,
439  make_tuple(
445 
446  return lds_desc_m_k;
447  }
448  }
449  else
450  {
451  // lanes within a wave load different M but same K
452  static_assert(WarpSize % LanesPerK == 0);
453  constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
454  constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
455 
456  constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
458  number<LaneGroups>{}, // m1
459  number<NumWarps>{}, // m2
460  number<LanesPerK>{}, // k0
461  number<KVector>{}), // k1
462  make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
463  number<Block_K>{}, // m1
465  number<KVector>{}, // k0
466  number<1>{}), // k1
467  number<KPack>{}, // lds load vector
468  number<1>{});
469 
470  constexpr auto lds_desc_m_k = transform_tensor_descriptor(
471  lds_block_desc_0,
472  make_tuple(
478 
479  return lds_desc_m_k;
480  }
481  }
482 
// MakeBridgeLdsLoadDesc<Problem>(): descriptor for reading the bridge buffer, sized
// from the gemm-0 block shape (Block_M0 x Block_N0) with no padding (KPad == 0) and a
// guaranteed last-dimension vector length of KVector.
// NOTE(review): the signature (listing line 484) and the descriptor-construction lines
// (493-494) are missing from this extract.
483  template <typename Problem>
485  {
486  constexpr index_t Block_M = Problem::BlockShape::Block_M0;
487  constexpr index_t Block_N = Problem::BlockShape::Block_N0;
488 
489  constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // YDataType elements per 16 bytes
490  constexpr index_t KPad = 0; // pad between warps
491 
492  constexpr auto desc =
495  number<KVector>{},
496  number<1>{});
497  return desc;
498  }
499 
// MakeBridgeLdsStoreDesc<Problem>(): descriptor for writing the bridge buffer; uses the
// same inputs as the bridge load descriptor (Block_M0, Block_N0, KVector, KPad == 0).
// GetSmemSize_Bridge asserts that the two descriptors span the same element space.
// NOTE(review): the signature (listing line 501) and the descriptor-construction lines
// (510-511) are missing from this extract.
500  template <typename Problem>
502  {
503  constexpr index_t Block_M = Problem::BlockShape::Block_M0;
504  constexpr index_t Block_N = Problem::BlockShape::Block_N0;
505 
506  constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // YDataType elements per 16 bytes
507  constexpr index_t KPad = 0; // KVector; // pad between warps
508 
509  constexpr auto desc =
512  number<KVector>{},
513  number<1>{});
514  return desc;
515  }
516 
// MakeBridgeLdsStoreForUKDesc<Problem>(): bridge store descriptor variant; the name
// suggests it serves the micro-kernel (UK) path — TODO confirm.
// The lane split is hard-coded: 16 lanes along M (kAMLane), 4 lanes along K (kABKLane),
// 4 elements per lane (kABKPerLane, also used as KPack).
// NOTE(review): the signature (listing line 518), the first length, most strides
// (537-540) and most of the transform (548-554) are missing from this extract.
// Repeat_M is presumably used on one of the missing lines.
517  template <typename Problem>
519  {
520  constexpr index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0;
521  constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N0;
522  constexpr index_t Repeat_M = Problem::BlockShape::Repeat_M0;
523 
524  constexpr index_t kAMLane = 16;
525  constexpr index_t kABKLane = 4;
526  constexpr index_t kABKPerLane = 4;
527 
528  constexpr index_t KPack = kABKPerLane;
529 
530  constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
532  number<Repeat_N>{}, // n
533  number<WarpPerBlock_N>{}, // n
534  number<kABKLane>{}, // n
535  number<kAMLane>{}, // m
536  number<KPack>{}), // n
541  number<KPack>{}, // m
542  number<1>{}), // n
543  number<KPack>{}, // lds store vector(actually no explicit store)
544  number<1>{});
545 
546  constexpr auto desc = transform_tensor_descriptor(
547  lds_block_desc_0,
552  number<KPack>{}))),
555 
556  return desc;
557  }
558 
// GetWarpGemm0<Problem>(): picks the warp-level GEMM for the first gemm from the A/G
// data types and the gemm-0 warp tile (Warp_M0 x Warp_N0 x Warp_K0).
//   bf16 x bf16, 32x32x16 -> first branch
//   int8 x int8, 32x32x32 -> second branch
// Any other combination instantiates no return statement (deduced return type void).
// NOTE(review): the lines naming the returned warp-gemm types (571-572, 579-580) are
// missing from this extract; wg_ctrl is presumably consumed there.
559  template <typename Problem>
560  CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
561  {
562  using S_ = typename Problem::BlockShape;
563  // A is vgpr, B is agpr. But since we transposed, so also need swap this
564  // TODO: this is ugly
565  constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
566  // TODO: ugly
567  if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
568  std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
569  S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
570  {
573  2>>{};
574  }
575  else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
576  std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
577  S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
578  {
581  2>>{};
582  }
583  }
584 
// GetSequencer_0<Problem>(): instruction-slot schedule for gemm 0. Each sequence entry
// is a mask saying what to issue in that slot (GLD_B, GLD_A, SLD_A, or 0 for nothing).
// Two bf16 configurations are covered, differing only in Block_N0 (512 vs 256); any
// other configuration instantiates no return statement.
// NOTE(review): the mask constant definitions (listing lines 592-594) and the first row
// of each table (606, 628) are missing from this extract.
585  template <typename Problem>
586  CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_0()
587  {
588  // this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
589  // the purpose is to hide those instructions under mfma
590  // every value inside seq<...> is a mask, indicating a specific operation
591  using S_ = typename Problem::BlockShape;
595  if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
596  std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
597  S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
598  S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
599  S_::Block_N1 == 128)
600  {
601  // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
602  // gld_a 8x ds_read_b128 sld_a total 64 slot :)
603  // clang-format off
604  constexpr auto seq_all =
605  // 0 1 2 3 4 5 6 7
607  GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 1
608  GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 2
609  GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 3
610  GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
611  GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
612  GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
613  GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
614  return seq_all;
615  // clang-format on
616  }
617  else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
618  std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
619  S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
620  S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
621  S_::Block_N1 == 128)
622  {
623  // Total 32 instructions, 16 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
624  // gld_a 8x ds_read_b128 sld_a total 64 slot :)
// NOTE(review): this table appears to have 4 rows of 8 (32 slots), so "total 64 slot"
// above looks copied from the first branch — confirm.
625  // clang-format off
626  constexpr auto seq_all =
627  // 0 1 2 3 4 5 6 7
629  GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 1
630  GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 2
631  GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A>{}; // 3
632  return seq_all;
633  // clang-format on
634  }
635  }
636 
// GetSequencer_1<Problem>(): instruction-slot schedule for gemm 1. Entries are masks
// (GLD_B, GST_O, or 0 for nothing). Two bf16 configurations are covered, keyed on the
// gemm-1 warp tile (Warp_M1/N1/K1) and differing only in Block_N0 (512 vs 256); any
// other configuration instantiates no return statement.
// NOTE(review): the mask constant definitions (listing lines 644-645) and the first row
// of each table (657, 679) are missing from this extract.
637  template <typename Problem>
638  CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1()
639  {
640  // this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
641  // the purpose is to hide those instructions under mfma
642  // every value inside seq<...> is a mask, indicating a specific operation
643  using S_ = typename Problem::BlockShape;
646  if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
647  std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
648  S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
649  S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
650  S_::Block_N1 == 128)
651  {
652  // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
653  // gld_a 8x ds_read_b128 sld_a total 64 slot :)
// NOTE(review): the visible rows hold GLD_B and GST_O only; the gld_a / sld_a wording
// above looks copied from GetSequencer_0 — confirm.
654  // clang-format off
655  constexpr auto seq_all =
656  // 0 1 2 3 4 5 6 7
658  GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 1
659  GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
660  GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 3
661  GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
662  GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
663  GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
664  GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
665  return seq_all;
666  // clang-format on
667  }
668  else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
669  std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
670  S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
671  S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
672  S_::Block_N1 == 128)
673  {
674  // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
675  // gld_a 8x ds_read_b128 sld_a total 64 slot :)
// NOTE(review): this table appears to have 4 rows of 8 (32 slots); the "64" counts
// above look copied from the first branch — confirm.
676  // clang-format off
677  constexpr auto seq_all =
678  // 0 1 2 3 4 5 6 7
680  GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 1
681  GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
682  GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 3
683  return seq_all;
684  // clang-format on
685  }
686  }
687 
// GetWarpGemm1<Problem>(): picks the warp-level GEMM for the second gemm from the Y/D
// data types and a warp tile size.
//   bf16 x bf16, 32x32x16 -> first branch
//   int8 x int8, 32x32x32 -> second branch
// Any other combination instantiates no return statement (deduced return type void).
// NOTE(review): the tile checks read Warp_M0/N0/K0, whereas GetSequencer_1 keys on
// Warp_M1/N1/K1 — confirm whether the gemm-1 warp tile was intended here.
// NOTE(review): the lines naming the returned warp-gemm types (698-699, 706-707) are
// missing from this extract; wg_ctrl is presumably consumed there.
688  template <typename Problem>
689  CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
690  {
691  using S_ = typename Problem::BlockShape;
692  constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
693  // TODO: ugly
694  if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
695  std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
696  S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
697  {
700  2>>{};
701  }
702  else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
703  std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
704  S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
705  {
708  2>>{};
709  }
710  }
711 
// MakeCBlockTile_Gemm0<Problem>(): creates the distributed accumulator tensor for
// gemm 0. A block-level outer encoding is embedded with the warp gemm's
// CWarpDstrEncoding, turned into a static tile distribution, and used to create a
// static distributed tensor of WarpGemm::CDataType.
// NOTE(review): the signature (listing line 713), the definition of WarpGemm (715-716)
// and most of the outer encoding (720-725) are missing from this extract.
712  template <typename Problem>
714  {
717  using CDataType = typename WarpGemm::CDataType;
718 
719  constexpr auto c_block_outer_dstr_encoding =
726  sequence<0, 0>>{};
727 
728  constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
729  c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
730  constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
731  auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
732  return c_block_tensor;
733  }
734 
// MakeCBlockTile_Gemm1<Problem>(): creates the distributed accumulator tensor for
// gemm 1, following the same steps as the gemm-0 variant (outer encoding embedded with
// WarpGemm::CWarpDstrEncoding -> static tile distribution -> distributed tensor).
// NOTE(review): the signature (listing line 736), the definition of WarpGemm (738-739)
// and most of the outer encoding (743-748) are missing from this extract.
735  template <typename Problem>
737  {
740  using CDataType = typename WarpGemm::CDataType;
741 
742  constexpr auto c_block_outer_dstr_encoding =
749  sequence<0, 0>>{};
750 
751  constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
752  c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
753  constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
754  auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
755  return c_block_tensor;
756  }
757 
758  // this is used as A matrix for 2nd gemm
// MakeYTileDistribution<Problem>(): embeds the warp gemm's AWarpDstrEncoding into a
// block-level outer encoding and returns the resulting static tile distribution.
// NOTE(review): the signature (listing line 760), the definition of WarpGemm (762-763)
// and most of the outer encoding (767-771) are missing from this extract.
759  template <typename Problem>
761  {
764 
765  // TODO: all waves are along different N, but same M
766  constexpr auto y_outer_dstr_enc =
772  sequence<0, 0>>{};
773 
774  constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
775  y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
776  constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
777  return y_block_dstr;
778  }
779 
780  template <typename Problem>
781  CK_TILE_HOST_DEVICE static constexpr auto MakeYBlockTile()
782  {
783  constexpr auto y_block_dstr = MakeYTileDistribution<Problem>();
784  auto y_block_tensor =
785  make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
786  return y_block_tensor;
787  }
788 
// GetUK_0<Problem>(): selects the gemm-0 micro-kernel object for a 32x512x128 block
// with a 16x16x32 warp tile; one branch for bf16 inputs, one for fp16. Any other
// configuration instantiates no return statement (deduced return type void).
// NOTE(review): the return statements (listing lines 798 and 805) are missing from
// this extract, so the concrete micro-kernel types are not visible here.
789  template <typename Problem>
790  CK_TILE_HOST_DEVICE static constexpr auto GetUK_0()
791  {
792  using S_ = typename Problem::BlockShape;
793  if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
794  std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
795  S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
796  S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
797  {
799  }
800  else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
801  std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
802  S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
803  S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
804  {
806  }
807  }
808 
// GetUK_1<Problem>(): selects the gemm-1 micro-kernel object for a 32x128x512 block
// (Block_M1 x Block_N1 x Block_K1) with float top-k weights. The four branches are
// {bf16, fp16} x {PipeInterleave false, true}; any other configuration instantiates
// no return statement (deduced return type void).
// NOTE(review): the warp-tile checks read Warp_M0/N0/K0 while the block checks read
// Block_*1 — confirm whether Warp_*1 was intended.
// NOTE(review): the return statements (listing lines 821, 831, 842, 852) are missing
// from this extract; only the commented-out alternatives remain visible.
809  template <typename Problem>
810  CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
811  {
812  using S_ = typename Problem::BlockShape;
813  using T_ = typename Problem::Traits;
814  if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
815  std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
816  std::is_same_v<typename Problem::TopkWeightDataType, float> &&
817  S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
818  S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
819  T_::PipeInterleave == false)
820  {
822  // return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
823  }
824  else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
825  std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
826  std::is_same_v<typename Problem::TopkWeightDataType, float> &&
827  S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
828  S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
829  T_::PipeInterleave == false)
830  {
832  // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
833  }
834  else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
835  std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
836  std::is_same_v<typename Problem::TopkWeightDataType, float> &&
837  S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
838  S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
839  T_::PipeInterleave == true)
840  {
841  // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
843  }
844  else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
845  std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
846  std::is_same_v<typename Problem::TopkWeightDataType, float> &&
847  S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
848  S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
849  T_::PipeInterleave == true)
850  {
851  // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
853  }
854  }
855 };
856 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:268
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:197
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:401
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:540
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:18
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:74
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:265
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:318
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:15
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_D()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:262
static constexpr CK_TILE_HOST_DEVICE auto MakeCBlockTile_Gemm1()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:736
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_O()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:278
static constexpr CK_TILE_HOST_DEVICE auto GetSmemKPack_Y()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:85
static constexpr CK_TILE_HOST_DEVICE auto GetSmemKPack()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:71
static constexpr CK_TILE_HOST_DEVICE auto GetAlignment_G()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:33
static constexpr CK_TILE_HOST_DEVICE auto MakeBridgeLdsStoreDesc()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:501
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:92
static constexpr CK_TILE_HOST_DEVICE auto MakeYTileDistribution()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:760
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize_Bridge()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:101
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_SimpleMxK_Async()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:164
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_SimpleMxK()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:119
static constexpr CK_TILE_HOST_DEVICE auto GetWarpGemm0()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:560
static constexpr CK_TILE_HOST_DEVICE index_t GetAsyncCopyDwords()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:16
static constexpr CK_TILE_HOST_DEVICE auto GetAlignment_O()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:51
static constexpr CK_TILE_HOST_DEVICE auto GetAlignment_D()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:42
static constexpr CK_TILE_HOST_DEVICE auto MakeBridgeLdsStoreForUKDesc()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:518
static constexpr CK_TILE_HOST_DEVICE auto GetAlignment_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:23
static constexpr CK_TILE_HOST_DEVICE auto GetSequencer_0()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:586
static constexpr CK_TILE_HOST_DEVICE auto MakeYBlockTile()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:781
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:230
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsStoreDesc_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:300
static constexpr CK_TILE_HOST_DEVICE auto GetSequencer_1()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:638
static constexpr CK_TILE_HOST_DEVICE auto MakeBridgeLdsLoadDesc()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:484
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:111
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsLoadDesc_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:389
static constexpr CK_TILE_HOST_DEVICE auto MakeCBlockTile_Gemm0()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:713
static constexpr CK_TILE_HOST_DEVICE auto GetUK_1()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:810
static constexpr CK_TILE_HOST_DEVICE auto GetSmemKPack_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:78
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_Nr_Kr_W()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:216
static constexpr CK_TILE_HOST_DEVICE auto GetUK_0()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:790
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_G()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:243
static constexpr CK_TILE_HOST_DEVICE auto GetWarpGemm1()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:689
Definition: warp_gemm_attribute_mfma_impl.hpp:1596
Definition: warp_gemm_attribute_mfma_impl.hpp:448
Definition: warp_gemm_impl.hpp:11
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192