// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"

#define ENABLE_ASM_MARKER 1
#if ENABLE_ASM_MARKER
#define ASM_MARKER(marker)                  \
    __builtin_amdgcn_sched_barrier(0);      \
    asm volatile("; [POYENC] " #marker);    \
    __builtin_amdgcn_sched_barrier(0);
#else
#define ASM_MARKER(marker)
#endif
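
// A sketch of what ASM_MARKER(phase1) expands to (illustrative only):
//
//   __builtin_amdgcn_sched_barrier(0);   // no instruction may move across this point
//   asm volatile("; [POYENC] phase1");   // pure ISA comment, emits no instruction
//   __builtin_amdgcn_sched_barrier(0);
//
// The surrounding sched_barrier(0) pair keeps the compiler from migrating code across
// the marker, so the "; [POYENC] ..." comment stays put in the generated ISA and can
// be grepped for in the disassembly to locate each pipeline phase.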

#define ADD_SBARRIER_FOR_PHASE0 1
#if !defined(CK_TILE_DISABLE_PACKED_FP32)
#define CK_TILE_DISABLE_PACKED_FP32 0
#endif

#define WARP_ID 0
#define LANE_ID 0

#define ENABLE_DEBUG_STMTS 1
#if ENABLE_DEBUG_STMTS
#define DEBUG_STMTS \
    if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID)
#else
#define DEBUG_STMTS if constexpr(false)
#endif
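
// Usage sketch (hypothetical variable names): print from exactly one thread, selected
// by the WARP_ID/LANE_ID macros above, e.g.
//
//   DEBUG_STMTS { printf("num_total_loop = %d\n", num_total_loop); }
//
// With ENABLE_DEBUG_STMTS set to 0 the guarded block compiles away via
// if constexpr(false).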

namespace ck_tile {

template <typename PipelineProblem, bool kIsMasking>
struct CoreLoopScheduler;

template <typename PipelineProblem>
struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
{
    template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
    static constexpr CK_TILE_DEVICE void schedule(ck_tile::number<WaveGroup>,
                                                  ck_tile::number<Phase>)
    {
        using namespace ck_tile;

        if constexpr(WaveGroup == 0)
        {
            if constexpr(Phase == 0)
            {
                static_for<0, 8, 1>{}([&](auto) {
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
                    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
                });
            }
            else if constexpr(Phase == 1)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
            }
            else if constexpr(Phase == 2)
            {
#if !CK_TILE_DISABLE_PACKED_FP32
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
#endif
                static_for<0, 8, 1>{}([&](auto) {
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
                });
            }
            else if constexpr(Phase == 3)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
            }
        }
        else
        {
            if constexpr(Phase == 0)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
            }
            else if constexpr(Phase == 1)
            {
                static_for<0, 8, 1>{}([&](auto) {
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
                    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
                });
            }
            else if constexpr(Phase == 2)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
            }
            else if constexpr(Phase == 3)
            {
#if !CK_TILE_DISABLE_PACKED_FP32
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
#endif
                static_for<0, 8, 1>{}([&](auto) {
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
                });
            }
        }
    }
};

template <typename PipelineProblem>
struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
{
    template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
    static constexpr CK_TILE_DEVICE void schedule(ck_tile::number<WaveGroup>,
                                                  ck_tile::number<Phase>)
    {
        using namespace ck_tile;

        if constexpr(WaveGroup == 0)
        {
            if constexpr(Phase == 0)
            {
                static_for<0, 8, 1>{}([&](auto) {
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
                    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
                });
            }
            else if constexpr(Phase == 1)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
            }
            else if constexpr(Phase == 2)
            {
#if !CK_TILE_DISABLE_PACKED_FP32
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
#endif
                static_for<0, 8, 1>{}([&](auto) {
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
                });
            }
            else if constexpr(Phase == 3)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
            }
        }
        else
        {
            if constexpr(Phase == 0)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
            }
            else if constexpr(Phase == 1)
            {
                static_for<0, 8, 1>{}([&](auto) {
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
                    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
                });
            }
            else if constexpr(Phase == 2)
            {
                __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
                __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
            }
            else if constexpr(Phase == 3)
            {
#if !CK_TILE_DISABLE_PACKED_FP32
                __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
#endif
                static_for<0, 8, 1>{}([&](auto) {
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
                });
            }
        }
    }
};
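
// How to read the schedules above: __builtin_amdgcn_sched_group_barrier(mask, size,
// sync_id) asks the scheduler to place a group of `size` instructions of the class
// selected by `mask` (per the comments above: 0x008 MFMA, 0x002 VALU, 0x004 SALU, ...)
// at this point; a run of consecutive barriers with the same sync_id forms an ordered
// pipeline. For example, the Phase == 0 pattern for wave group 0 requests eight
// repetitions of
//
//   1 MFMA  ->  2 TRANS  ->  2 VALU
//
// which interleaves the softmax's transcendental and vector-ALU work with the MFMAs
// of gemm0 instead of letting either run in one long burst.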

namespace detail {
CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
{
#if CK_TILE_DISABLE_PACKED_FP32
    return a * b + c;
#else
    float result;
    asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]"
                 : [result] "=v"(result)
                 : [a] "v"(a), [b] "s"(b), [c] "v"(c));
    return result;
#endif
}

CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
{
    float result;
    asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]"
                 : [result] "=v"(result)
                 : [lhs] "v"(lhs), [rhs] "v"(rhs));
    return result;
}

CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
{
    float result;
    asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]"
                 : [result] "=v"(result)
                 : [lhs] "v"(lhs), [rhs] "v"(rhs));
    return result;
}

CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
{
    fp16x2_t result;
    asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]"
                 : [result] "=v"(result)
                 : [a] "v"(a), [b] "v"(b));
    return result;
}

CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
{
    bf16x2_t result;
    asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]"
                 : [result] "=v"(result)
                 : [a] "v"(a), [b] "v"(b));
    return result;
}

CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
{
    fp32x2_t result;
    asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]"
                 : [result] "=v"(result)
                 : [lhs] "v"(lhs), [rhs] "v"(rhs));
    return result;
}
} // namespace detail
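
// Why inline asm instead of plain C++ arithmetic: the register constraints pin each
// operand to a register class. In fma_impl_vsv the "s" constraint forces the middle
// operand into an SGPR, so v_fma_f32 consumes a scalar operand directly instead of
// first broadcasting it into a VGPR, while the "v" constraints keep the other two
// operands in VGPRs. The softmax below calls fma_impl_vsv(s, scale_s, -scale_s * m),
// which relies on exactly this vector*scalar+vector operand shape.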

template <typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
struct BlockFmhaFwdV3Pipeline
{
    using Problem             = ck_tile::remove_cvref_t<Problem_>;
    using Policy              = ck_tile::remove_cvref_t<Policy_>;
    using QDataType           = ck_tile::remove_cvref_t<typename Problem::QDataType>;
    using KDataType           = ck_tile::remove_cvref_t<typename Problem::KDataType>;
    using VDataType           = ck_tile::remove_cvref_t<typename Problem::VDataType>;
    using SaccDataType        = ck_tile::remove_cvref_t<typename Problem::SaccDataType>;
    using SMPLComputeDataType = ck_tile::remove_cvref_t<typename Problem::SMPLComputeDataType>;
    using LSEDataType         = ck_tile::remove_cvref_t<typename Problem::LSEDataType>;
    using PDataType           = ck_tile::remove_cvref_t<typename Problem::PDataType>;
    using OaccDataType        = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
    using ODataType           = ck_tile::remove_cvref_t<typename Problem::ODataType>;
    using FmhaMask            = ck_tile::remove_cvref_t<typename Problem::FmhaMask>;

    static_assert(is_generic_attention_mask_v<FmhaMask>);

    static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
                  "we will use the same dist tensor 'sp_compute' for both gemm0 & softmax");

    using BlockFmhaShape = ck_tile::remove_cvref_t<typename Problem::BlockFmhaShape>;

    using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
    static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);

    static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize;

    static constexpr ck_tile::index_t kM0           = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0           = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0           = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1           = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1           = BlockFmhaShape::kK1;
    static constexpr ck_tile::index_t kQKHeaddim    = BlockFmhaShape::kQKHeaddim;
    static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;

    static_assert(kQKHeaddim == 128 && kSubQKHeaddim == 128, "only supports hdim=hdim_v=128");

    static constexpr bool kIsGroupMode      = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ       = Problem::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK       = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ      = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV      = Problem::kPadHeadDimV;
    static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
    static constexpr auto BiasEnum          = Problem::BiasEnum;
    static constexpr bool kStoreLSE         = Problem::kStoreLSE;
    static constexpr bool kHasDropout       = Problem::kHasDropout;
    static constexpr auto QScaleEnum        = Problem::QScaleEnum;
    static constexpr bool kSkipMinSeqlenQ   = Problem::kSkipMinSeqlenQ;
    static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
                   !kStoreLSE && !kHasDropout && !kSkipMinSeqlenQ),
                  "unsupported features enabled");

    // last-dimension vector length used to create the tensor view (and to decide the
    // buffer_load vector length), together with the tensor distribution; the tensor
    // distribution should be able to override this
    static constexpr ck_tile::index_t kAlignmentQ =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
    static constexpr ck_tile::index_t kAlignmentK =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
    static constexpr ck_tile::index_t kAlignmentV =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();

    static constexpr ck_tile::index_t kAlignmentO =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();

    static constexpr ck_tile::index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            return 2;
        }
    }();

    static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
    {
        // create another LDS buffer for p
        return ck_tile::max(kM0 * kN1 * sizeof(PDataType),
                            Policy::template GetSmemSize<Problem>() +
                                kM0 * kN0 * sizeof(PDataType));
    }
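
    // Reading the max() above: during the main loop the LDS holds the policy-managed
    // K/V staging buffers plus a kM0 x kN0 P scratch placed right after them; in the
    // epilogue the same LDS is reinterpreted as a kM0 x kN1 staging area for O (see
    // o_lds below, placed at offset 0). GetSmemSize() returns whichever layout is
    // larger so that one allocation covers both uses.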

    // for debug only
    template <ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock>
    CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc()
    {
        using namespace ck_tile;
        constexpr auto lds_block_desc =
            make_naive_tensor_descriptor(make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
                                         make_tuple(number<NPerBlock>{}, number<1>{}),
                                         number<1>{},
                                         number<1>{});

        return lds_block_desc;
    }

    // for debug only
    template <ck_tile::index_t MPerBlock>
    CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc1D()
    {
        using namespace ck_tile;
        constexpr auto lds_block_desc = make_naive_tensor_descriptor(
            make_tuple(number<MPerBlock>{}), make_tuple(number<1>{}));

        return lds_block_desc;
    }

    template <typename DataType, typename Descriptor>
    CK_TILE_DEVICE static constexpr auto make_lds_tile_window(void* base, const Descriptor& desc)
    {
        using namespace ck_tile;

        auto tensor_view =
            make_tensor_view<address_space_enum::lds>(reinterpret_cast<DataType*>(base), desc);
        return make_tile_window(tensor_view, desc.get_lengths(), {0, 0});
    }

    // vmcnt=0~63, lgkmcnt=0~15, expcnt=0~7
    template <uint16_t Vmcnt, uint8_t Lgkmcnt, uint8_t Expcnt = 7>
    CK_TILE_DEVICE static constexpr void s_waitcnt()
    {
        // vmcnt uses bits {[15:14],[3:0]}
        // expcnt uses bits [6:4]
        // lgkmcnt uses bits [11:8]
        __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) |
                                   ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8));
    }
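
    // Worked example of the encoding above: s_waitcnt_vmcnt<0>() instantiates
    // s_waitcnt<0, 15>(), i.e. Vmcnt=0, Lgkmcnt=15, Expcnt=7, which packs to
    //
    //   vmcnt   = 0  -> bits {[15:14],[3:0]} = 0x0000
    //   expcnt  = 7  -> bits [6:4]           = 0x0070
    //   lgkmcnt = 15 -> bits [11:8]          = 0x0F00
    //   immediate                            = 0x0F70
    //
    // so the wave blocks until zero vector-memory ops are outstanding while leaving
    // the LDS/GDS (lgkm) and export counters at their "don't wait" maxima.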

    template <uint16_t Vmcnt>
    CK_TILE_DEVICE static constexpr void s_waitcnt_vmcnt()
    {
        s_waitcnt<Vmcnt, 15>();
    }

    template <uint8_t Lgkmcnt>
    CK_TILE_DEVICE static constexpr void s_waitcnt_lgkmcnt()
    {
        s_waitcnt<63, Lgkmcnt>();
    }

    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename QElementFunction,
              typename KElementFunction,
              typename VElementFunction,
              typename LSEElementFunction,
              typename SAccElementFunction,
              typename PComputeElementFunction,
              typename OAccElementFunction>
    CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
                                   const QElementFunction& q_element_func,
                                   const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
                                   [[maybe_unused]] const KElementFunction& k_element_func,
                                   const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
                                   [[maybe_unused]] const VElementFunction& v_element_func,
                                   LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
                                   const LSEElementFunction& lse_element_func,
                                   [[maybe_unused]] const SAccElementFunction& s_acc_element_func,
                                   const PComputeElementFunction& p_compute_element_func,
                                   const OAccElementFunction& o_acc_element_func,
                                   FmhaMask mask,
                                   float scale_s,
                                   void* smem_ptr) const
    {
        using namespace ck_tile;

        static_assert(
            std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
                std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
        auto s_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
            MakeSimpleLdsDesc<kM0, kN0>());
        [[maybe_unused]] auto s_lds_window =
            make_tile_window(s_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

        auto p_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr) +
                                         Policy::template GetSmemSize<Problem>()),
            MakeSimpleLdsDesc<kM0, kN0>());
        [[maybe_unused]] auto p_lds_window =
            make_tile_window(p_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

        auto o_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
            MakeSimpleLdsDesc<kM0, kN1>());
        [[maybe_unused]] auto o_lds_window =
            make_tile_window(o_lds, make_tuple(number<kM0>{}, number<kN1>{}), {0, 0});

        auto m_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<SMPLComputeDataType*>(static_cast<char*>(smem_ptr) +
                                                   Policy::template GetSmemSize<Problem>()),
            MakeSimpleLdsDesc1D<kM0>());
        [[maybe_unused]] auto m_lds_window =
            make_tile_window(m_lds, make_tuple(number<kM0>{}), {0});

        const index_t warp_group_id = get_warp_id() / 4;

        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();

        auto q_dram_window = make_tile_window_linear(
            q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution<Problem>());

        // reduction function for softmax
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

        auto k_lds_window_store = generate_tuple(
            [&](auto i_buf) {
                return make_lds_tile_window<KDataType>(
                    smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
            },
            number<2>{});

        auto v_lds_window_store = generate_tuple(
            [&](auto i_buf) {
                return make_lds_tile_window<KDataType>(
                    smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
            },
            number<2>{});

        statically_indexed_array<decltype(make_tile_window(
                                     make_lds_tile_window<KDataType>(
                                         nullptr,
                                         Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
                                     Policy::template MakeKRegTileDistribution<Problem>())),
                                 2>
            k_lds_window_load;

        statically_indexed_array<decltype(make_tile_window(
                                     make_lds_tile_window<VDataType>(
                                         nullptr,
                                         Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
                                     Policy::template MakeVRegTileDistribution<Problem>())),
                                 2>
            v_lds_window_load;

        decltype(make_static_distributed_tensor<QDataType>(
            Policy::template MakeQRegTileDistribution<Problem>())) q_tile;

        union kv_tile_type
        {
            CK_TILE_DEVICE kv_tile_type() {}

            decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile;

            decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile;
        } kv_tile;
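
        // The union above is a VGPR-pressure trick: the K register tile (consumed by
        // gemm0) and the transposed V register tile (consumed by gemm1) are never live
        // in the same phase, so letting them alias saves a full tile's worth of
        // registers at the cost of the empty user-provided constructor.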

        union sp_compute_type
        {
            CK_TILE_DEVICE sp_compute_type() {}

            decltype(gemm_0.MakeCBlockTile()) sp_compute;
            decltype(make_static_distributed_tensor<PDataType>(
                Policy::template MakePRegTileDistribution<Problem>())) p;
        };
        statically_indexed_array<sp_compute_type, 2> sp;
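
        // sp is double-buffered so the S tile being softmax-ed and the P tile feeding
        // gemm1 can ping-pong between phases. Within one buffer, the union lets
        // fmha_alu1() narrow the fp32 scores into packed fp16/bf16 in place: each pair
        // of packed writes lands at or below the bytes it just read (forward index
        // order), so no live value is clobbered and P ends up in the same registers.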

        decltype(gemm_1.MakeCBlockTile()) o_acc;
        constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold deciding how many fmha_alu_D_upd()
                                                  // updates we move into fmha_alu1()
        static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());

        decltype(block_tile_reduce<SMPLComputeDataType>(
            sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m;
        decltype(m) l;

        // initialize k_lds_window and v_lds_window
        static_for<0, 2, 1>{}([&](auto idx) {
            k_lds_window_load(idx) = make_tile_window(
                make_lds_tile_window<KDataType>(
                    static_cast<char*>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
                    Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
                Policy::template MakeKRegTileDistribution<Problem>());
        });

        static_for<0, 2, 1>{}([&](auto idx) {
            v_lds_window_load(idx) =
                make_tile_window(make_lds_tile_window<VDataType>(
                                     static_cast<char*>(smem_ptr) +
                                         (idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
                                     Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
                                 Policy::template MakeVRegTileDistribution<Problem>());
        });

        {
            auto origin_q      = load_tile(q_dram_window);
            auto transformed_q = tile_elementwise_in(q_element_func, origin_q);

            q_tile = transformed_q;
        }

        clear_tile(o_acc);
        set_tile(m, bit_cast<float>(0xff7fffff)); // -FLT_MAX, i.e. slightly greater than -infinity
        clear_tile(l);

        const auto q_origin = q_dram_window.get_window_origin();
        const auto [seqlen_k_start, seqlen_k_end] =
            mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

        const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
        index_t kv_token_start    = seqlen_k_start;

        // early exit if there is no work to do
        if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
        {
            if(num_total_loop <= 0)
            {
                if constexpr(kStoreLSE)
                {
                    auto lse =
                        make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

                    set_tile(lse, -numeric<SMPLComputeDataType>::infinity());

                    store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
                }

                // Note: o_acc is already cleared above, so return it as-is.
                // Note: q was loaded without a fence; that load can be ignored.
                return o_acc;
            }
        }

        auto k_dram_window =
            make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                             k_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_k_start, 0},
                             Policy::template MakeKDramTileDistribution<Problem>());
        k_dram_window.init_raw();

        auto v_dram_window =
            make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
                             v_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_k_start, 0}, // TODO: hdim split?
                             Policy::template MakeVDramTileDistribution<Problem>());
        v_dram_window.init_raw();

        // prefetch K tile
        index_t i_total_loops      = 0;
        constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;
        static_assert(1 == k0_loops);
        static_assert(1 == k1_loops);
        static_assert(kN0 == kK1);

        constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
        static_assert(NumWarpGroups == 2);

        [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor, const char* name) {
            printf("[POYENC] %s (size=%d): %5.2f",
                   name,
                   decltype(dist_tensor.thread_buf_)::size(),
                   ck_tile::type_convert<float>(dist_tensor.thread_buf_[0]));
            static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) {
                printf(", %5.2f", ck_tile::type_convert<float>(dist_tensor.thread_buf_[i]));
            });
            printf("\n");
        };

        [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) {
            const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{});
            const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{});

            auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
            auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;

            if constexpr(true || num_rows < num_cols)
            {
                for(int row = 0; row < num_rows; ++row)
                {
                    int offset = desc.calculate_offset(make_tuple(row, 0));
                    printf("[DEVICE] %s[%3d] = %5.2f",
                           name,
                           row,
                           ck_tile::type_convert<float>(data[offset]));
                    for(int col = 1; col < num_cols; ++col)
                    {
                        printf(", ");
                        offset = desc.calculate_offset(make_tuple(row, col));
                        printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
                    }
                    printf("\n");
                }
            }
            else
            {
                for(int col = 0; col < num_cols; ++col)
                {
                    int offset = desc.calculate_offset(make_tuple(0, col));
                    printf("[DEVICE] %s[%3d] = %5.2f",
                           name,
                           col,
                           ck_tile::type_convert<float>(data[offset]));
                    for(int row = 1; row < num_rows; ++row)
                    {
                        printf(", ");
                        offset = desc.calculate_offset(make_tuple(row, col));
                        printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
                    }
                    printf("\n");
                }
            }
        };

        [[maybe_unused]] auto print_lds_1d = [&](auto lds_tile_window, const char* name) {
            const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{});

            auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
            auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;

            int offset = desc.calculate_offset(make_tuple(0));
            printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert<float>(data[offset]));
            for(int e = 1; e < num_elems; ++e)
            {
                printf(", ");
                offset = desc.calculate_offset(make_tuple(e));
                printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
            }
            printf("\n");
        };

        // K_mem_su_ld_insts = 1 for 32 x 128
        // V_mem_su_ld_insts = 1 for 128 x 32
        constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
        constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();

        auto K_mem_load = [&](auto k_lds_write_idx) {
            async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);

            // move K tile windows
            move_tile_window(k_dram_window, {kN0, 0});
        };

        auto K_lds_load = [&](auto k_lds_read_idx) {
            kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx));
        };

        auto V_mem_load = [&](auto v_lds_write_idx) {
            async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);

            // move V tile windows
            move_tile_window(v_dram_window, {kK1, 0});
        };

        auto V_lds_load = [&](auto v_lds_read_idx) {
            kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
        };

        decltype(m) m_old;
        SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd()

        statically_indexed_array<decltype(sp(number<0>{}).sp_compute), 2> sp_delta;

        auto fmha_alu0 = [&](auto sp_reg_idx) {
            m_old = m; // m{j-1}
            static_assert(m.thread_buf_.size() == 1,
                          "assuming that each thread holds 1 rowmax value");
            auto m_latest = block_tile_reduce<SMPLComputeDataType>(
                sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]);
#if defined(__gfx950__)
            // assuming that we are using 32x32 mfma
            int32x2_t swapped_regs =
                __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(m_latest.thread_buf_[0]),
                                                 bit_cast<int32_t>(m_latest.thread_buf_[0]),
                                                 false,
                                                 false);
            m_latest.thread_buf_[0] = f_max(bit_cast<SMPLComputeDataType>(swapped_regs.x),
                                            bit_cast<SMPLComputeDataType>(swapped_regs.y));
#else
            block_tile_reduce_sync(m_latest, f_max, bool_constant<false>{});
#endif
            m = m_latest;

            constexpr auto p_spans =
                std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
            sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
                        sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
                });
            });
        };
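
        // In flash-attention terms, fmha_alu0 is the first half of the online-softmax
        // update for tile j (with scale_s presumably folding in log2(e), since
        // ck_tile::exp2 is used downstream):
        //
        //   m_j     = max(m_{j-1}, rowmax(S_j))        (cross-lane via permlane32 swap
        //                                               or block_tile_reduce_sync)
        //   delta_j = scale_s * S_j - scale_s * m_j    (one v_fma_f32 per element)
        //
        // fmha_alu1 below finishes with P_j = exp2(delta_j).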

        auto fmha_alu1 = [&](auto sp_reg_idx) {
            constexpr auto p_spans =
                std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
            sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    sp(sp_reg_idx).sp_compute(i_j_idx) =
                        ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx));
                });
            });

            auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
                sp(sp_reg_idx).sp_compute,
                sequence<1>{},
                f_sum,
                SMPLComputeDataType{0}); // rowsum(Pcompute{j})
            static_assert(rowsum_p.thread_buf_.size() == 1,
                          "assuming that each thread holds 1 rowsum value");
#if defined(__gfx950__)
            // assuming that we are using 32x32 mfma
            int32x2_t swapped_regs =
                __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
                                                 bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
                                                 false,
                                                 false);
            rowsum_p.thread_buf_[0] = f_sum(bit_cast<SMPLComputeDataType>(swapped_regs.x),
                                            bit_cast<SMPLComputeDataType>(swapped_regs.y));
#else
            block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
#endif

            // l{j}
            constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
            sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                const auto tmp       = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));

                l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
            });

            // update partial o_acc [0, fmha_alu_D_reg_cnt)
            static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) {
                o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale);
            });

            // narrow sp_compute to packed PDataType, two values at a time
            static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
            static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
                float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
                float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
                if constexpr(std::is_same_v<PDataType, fp16_t>)
                {
                    auto casted = detail::cvt_pk_fp16_f32(x, y);
                    sp(sp_reg_idx).p.thread_buf_[idx]     = casted.x;
                    sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
                }
                else
                {
                    auto casted = detail::cvt_pk_bf16_f32(x, y);
                    sp(sp_reg_idx).p.thread_buf_[idx]     = casted.x;
                    sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
                }
            });
        };
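
        // ... and the matching running-sum / rescale step, split across helpers to
        // balance VALU work between phases:
        //
        //   l_j   = exp2(scale_s * (m_{j-1} - m_j)) * l_{j-1} + rowsum(P_j)
        //   O_acc = o_acc_scale * O_acc, with o_acc_scale = exp2(scale_s * (m_{j-1} - m_j))
        //
        // fmha_alu1() rescales only the first fmha_alu_D_reg_cnt accumulator registers;
        // fmha_alu_D_upd() computes o_acc_scale and rescales the rest (packed, two at a
        // time via v_pk_mul_f32 unless CK_TILE_DISABLE_PACKED_FP32 is set).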

        auto gemm = [&](auto sp_reg_idx, auto gemm_idx) {
            if constexpr(gemm_idx == 0)
            {
                clear_tile(sp(sp_reg_idx).sp_compute); // initialize C
                gemm_0(sp(sp_reg_idx).sp_compute,
                       get_slice_tile(q_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kM0, k0_loops * kK0>{}),
                       get_slice_tile(kv_tile.k_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kN0, k0_loops * kK0>{}));
            }
            else
            {
                gemm_1(o_acc,
                       get_slice_tile(sp(sp_reg_idx).p,
                                      sequence<0, (k1_loops - 1) * kK1>{},
                                      sequence<kM0, k1_loops * kK1>{}),
                       get_slice_tile(kv_tile.v_tile,
                                      sequence<0, (k1_loops - 1) * kK1>{},
                                      sequence<kN1, k1_loops * kK1>{}));
            }
        };

        auto cl_calc = [&](auto sp_reg_idx, auto gemm_idx) {
            if constexpr(gemm_idx == 0)
            {
                clear_tile(sp(sp_reg_idx).sp_compute); // initialize C
                gemm_0(sp(sp_reg_idx).sp_compute,
                       get_slice_tile(q_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kM0, k0_loops * kK0>{}),
                       get_slice_tile(kv_tile.k_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kN0, k0_loops * kK0>{}));
            }
            else
            {
                gemm_1(o_acc,
                       get_slice_tile(sp(sp_reg_idx).p,
                                      sequence<0, (k1_loops - 1) * kK1>{},
                                      sequence<kM0, k1_loops * kK1>{}),
                       get_slice_tile(kv_tile.v_tile,
                                      sequence<0, (k1_loops - 1) * kK1>{},
                                      sequence<kN1, k1_loops * kK1>{}));
                fmha_alu0(number<1>{} - sp_reg_idx);
            }
        };

        auto fmha_alu_D_upd = [&] {
            o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));

            fp32x2_t pk_o_acc_scale;
            pk_o_acc_scale.x = o_acc_scale;
            pk_o_acc_scale.y = o_acc_scale;

            static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0);
#if CK_TILE_DISABLE_PACKED_FP32
            static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size());
            static_for<fmha_alu_D_reg_cnt, fmha_alu_D_reg_cnt + 2, 1>{}(
                [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
#endif

            constexpr auto issued_D_reg_cnt =
#if CK_TILE_DISABLE_PACKED_FP32
                fmha_alu_D_reg_cnt + 2
#else
                fmha_alu_D_reg_cnt
#endif
                ;

            // update partial o_acc after [issued_D_reg_cnt]
            static_for<issued_D_reg_cnt, o_acc.thread_buf_.size(), 2>{}([&](auto idx) {
                fp32x2_t input;
                input.x = o_acc.thread_buf_[idx];
                input.y = o_acc.thread_buf_[idx + 1];

                auto output = detail::pk_mul_f32(input, pk_o_acc_scale);

                o_acc.thread_buf_[idx]     = output.x;
                o_acc.thread_buf_[idx + 1] = output.y;
            });
        };

        auto fmha_mask = [&](auto sp_reg_idx) {
            if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
            {
                bool need_perpixel_check = mask.IsEdgeTile(
                    q_origin.at(number<0>{}), kv_token_start, number<kM0>{}, number<kN0>{});
                if(need_perpixel_check)
                {
                    set_tile_if(sp(sp_reg_idx).sp_compute,
                                -numeric<SMPLComputeDataType>::infinity(),
                                [&](auto tile_idx) {
                                    const auto row =
                                        q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                                    const auto col = kv_token_start + tile_idx.at(number<1>{});
                                    return mask.IsOutOfBound(row, col);
                                });
                }
            }
        };
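
        // Masking cost is paid only on boundary tiles: IsEdgeTile() is a cheap
        // whole-tile test, and only when it fires does the per-element predicate walk
        // the tile and set out-of-bound scores to -inf (so they vanish after exp2).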

        auto cl_load = [&](auto load_type, auto mem_wr_idx, auto lds_rd_idx) {
            if constexpr(load_type == 0)
            {
                V_mem_load(mem_wr_idx);
                K_lds_load(lds_rd_idx);
            }
            else
            {
                K_mem_load(mem_wr_idx);
                V_lds_load(lds_rd_idx);
            }
        };

        auto core_loop = [&](auto cl_p) {
            auto gemm0 = number<0>{};
            auto gemm1 = number<1>{};

            auto memV = number<0>{};
            auto memK = number<1>{};

            using Scheduler = CoreLoopScheduler<Problem, FmhaMask::IsMasking>;

            auto iteration = [&](auto pi) {
                auto xdl_SP_p01_reg_idx = number<1>{} - pi;
                auto xdl_SP_p23_reg_idx = pi;

                auto K_w0_lds_wr_idx = number<1>{} - pi;
                auto V_w0_lds_wr_idx = pi;
                auto K_w0_lds_rd_idx = pi;
                auto V_w0_lds_rd_idx = pi;

                auto K_w4_lds_wr_idx = number<1>{} - pi;
                auto V_w4_lds_wr_idx = number<1>{} - pi;
                auto K_w4_lds_rd_idx = number<1>{} - pi;
                auto V_w4_lds_rd_idx = pi;

                bool result = true;

                if constexpr(cl_p == 0)
                {
#if ADD_SBARRIER_FOR_PHASE0
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
#endif
                    __builtin_amdgcn_sched_barrier(0);
                    // phase0
                    if constexpr(pi == 0)
                    {
                        ASM_MARKER("phase0 Wave0-3 (pi=0)");
                    }
                    else
                    {
                        ASM_MARKER("phase0 Wave0-3 (pi=1)");
                    }
                    s_waitcnt_lgkmcnt<0>();
                    __builtin_amdgcn_sched_barrier(0);
                    cl_calc(xdl_SP_p01_reg_idx, gemm0);
                    fmha_alu1(xdl_SP_p23_reg_idx);

                    Scheduler::schedule(cl_p, number<0>{});
                    __builtin_amdgcn_sched_barrier(0);
                    // phase1
                    ASM_MARKER("phase1 Wave0-3");
                    s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
                    Scheduler::schedule(cl_p, number<1>{});
                    fmha_mask(xdl_SP_p01_reg_idx);

                    __builtin_amdgcn_sched_barrier(0);
                    // phase2
                    ASM_MARKER("phase2 Wave0-3");
                    s_waitcnt_lgkmcnt<0>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    asm volatile("s_nop 0");
                    __builtin_amdgcn_sched_barrier(0);
                    cl_calc(xdl_SP_p23_reg_idx, gemm1);

                    Scheduler::schedule(cl_p, number<2>{});
                    __builtin_amdgcn_sched_barrier(0);
                    fmha_alu_D_upd();

                    __builtin_amdgcn_sched_barrier(0);
                    // phase3
                    ASM_MARKER("phase3 Wave0-3");
                    s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);

                    Scheduler::schedule(cl_p, number<3>{});
                    kv_token_start += kN0;
                    if(num_total_loop <= ++i_total_loops)
                    {
                        result = false;
                    }
                }
                else
                {
#if ADD_SBARRIER_FOR_PHASE0
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
#endif
                    __builtin_amdgcn_sched_barrier(0);
                    // phase0
                    if constexpr(pi == 0)
                    {
                        ASM_MARKER("phase0 Wave4-7 (pi=0)");
                    }
                    else
                    {
                        ASM_MARKER("phase0 Wave4-7 (pi=1)");
                    }
                    cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx);

                    Scheduler::schedule(cl_p, number<0>{});
                    __builtin_amdgcn_sched_barrier(0);
                    // phase1
                    ASM_MARKER("phase1 Wave4-7");
                    s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    asm volatile("s_nop 1");
                    __builtin_amdgcn_sched_barrier(0);
                    cl_calc(xdl_SP_p01_reg_idx, gemm0);
                    fmha_alu1(xdl_SP_p23_reg_idx);

                    Scheduler::schedule(cl_p, number<1>{});
                    __builtin_amdgcn_sched_barrier(0);
                    // phase2
                    ASM_MARKER("phase2 Wave4-7");
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
                    Scheduler::schedule(cl_p, number<2>{});
                    fmha_mask(xdl_SP_p01_reg_idx);

                    kv_token_start += kN0;
                    if(num_total_loop <= ++i_total_loops)
                    {
                        result = false;
                    }

                    __builtin_amdgcn_sched_barrier(0);
                    // phase3
                    ASM_MARKER("phase3 Wave4-7");
                    s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    asm volatile("s_nop 1");
                    __builtin_amdgcn_sched_barrier(0);
                    cl_calc(xdl_SP_p23_reg_idx, gemm1);

                    Scheduler::schedule(cl_p, number<3>{});
                    __builtin_amdgcn_sched_barrier(0);
                    fmha_alu_D_upd();
                }
                return result;
            };
            return iteration(number<0>{}) && iteration(number<1>{});
        };
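
        // Phase map of one core_loop iteration, as read from the code above. The two
        // wave groups execute the same four phases shifted by one, so one group's MFMA
        // phases overlap the other group's memory phases:
        //
        //   waves 0-3: ph0 gemm0+fmha_alu1          | ph1 K mem ld, V lds ld, mask
        //              ph2 gemm1+fmha_alu0+D_upd    | ph3 V mem ld, K lds ld
        //   waves 4-7: ph0 V mem ld, K lds ld       | ph1 gemm0+fmha_alu1
        //              ph2 K mem ld, V lds ld, mask | ph3 gemm1+fmha_alu0+D_upd
        //
        // s_barrier delimits the phases; the s_setprio calls below give waves 4-7 the
        // higher issue priority.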

        auto fmha_post_process = [&](auto d) {
            auto ps_pi        = number<1>{} - d;
            auto V_lds_rd_idx = ps_pi;

            if(1 < num_total_loop)
            {
                s_waitcnt_vmcnt<K_mem_su_ld_insts>();
            }
            else
            {
                s_waitcnt_vmcnt<0>();
            }
            __builtin_amdgcn_s_barrier();

            V_lds_load(V_lds_rd_idx);
            fmha_alu1(ps_pi);

            s_waitcnt_lgkmcnt<0>();

            auto xdl_SP_p23_reg_idx = ps_pi;
            gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{});
        };

        // pre-stage
        {
            ASM_MARKER("before pre-stage");
            // (1) load K0 to LDS & VGPR
            K_mem_load(number<0>{}); // mem_K0

            s_waitcnt_vmcnt<0>();
            __builtin_amdgcn_s_barrier();

            K_lds_load(number<0>{}); // lds_K0

            s_waitcnt_lgkmcnt<0>();
            __builtin_amdgcn_s_barrier();

            // (2) prefetch K1 and V0 to LDS in parallel with GEMM0
            if(1 < num_total_loop)
            {
                K_mem_load(number<1>{}); // mem_K1
            }
            V_mem_load(number<0>{}); // mem_V0

            // (3) mfma (Q*K0) + softmax
            gemm(number<0>{}, /*gemm_idx=*/number<0>{});

            fmha_mask(number<0>{});
            fmha_alu0(number<0>{});
            fmha_alu_D_upd();

            kv_token_start += kN0;
            ++i_total_loops;
            if(num_total_loop <= i_total_loops)
            {
                goto label_main_loops_exit;
            }

            if(2 < num_total_loop)
            {
                K_mem_load(number<0>{}); // mem_K2

                s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
                __builtin_amdgcn_s_barrier();
            }

            ASM_MARKER("end pre-stage");
        }

        if(1 < num_total_loop)
        {
            if(warp_group_id == 0)
            {
                V_mem_load(number<1>{}); // V1
                K_lds_load(number<1>{}); // K1

                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_s_barrier();
                while(core_loop(number<0>{}))
                    ;
            }
            if(warp_group_id != 0)
            {
                __builtin_amdgcn_s_setprio(1);
                __builtin_amdgcn_s_barrier();
                while(core_loop(number<1>{}))
                    ;
            }
        }
    label_main_loops_exit:
        if(num_total_loop % 2)
        {
            fmha_post_process(number<1>{});
        }
        if(!(num_total_loop % 2))
        {
            fmha_post_process(number<0>{});
        }

        // store lse
        if constexpr(kStoreLSE)
        {
            auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

            constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
            sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                lse(i_idx)           = m[i_idx] / C_LOG2E + log(l[i_idx]);
            });

            store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
        }

        // finally, O
        constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();

        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            const auto tmp       = [&]() {
                if constexpr(FmhaMask::IsMasking)
                {
                    return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
                }
                else
                    return 1 / l[i_idx];
            }();
            sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                o_acc(i_j_idx) *= tmp;
            });
        });

        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);

        return o_acc;
    }

    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
                                   const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
                                   const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
                                   LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
                                   FmhaMask mask,
                                   float scale_s,
                                   void* smem_ptr) const
    {
        using namespace ck_tile;

        return operator()(q_dram_block_window_tmp,
                          identity{},
                          k_dram_block_window_tmp,
                          identity{},
                          v_dram_block_window_tmp,
                          identity{},
                          lse_dram_block_window_tmp,
                          identity{},
                          identity{},
                          identity{},
                          identity{},
                          mask,
                          scale_s,
                          smem_ptr);
    }
};

} // namespace ck_tile