/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp Source File
block_fmha_fwd_v3_pipeline.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
11 
12 #define ENABLE_ASM_MARKER 1
13 #if ENABLE_ASM_MARKER
14 #define ASM_MARKER(marker) \
15  __builtin_amdgcn_sched_barrier(0); \
16  asm volatile("; [POYENC] " #marker); \
17  __builtin_amdgcn_sched_barrier(0);
18 #else
19 #define ASM_MARKER(marker)
20 #endif
21 
22 #define ADD_SBARRIER_FOR_PHASE0 1
23 #if !defined(CK_TILE_DISABLE_PACKED_FP32)
24 #define CK_TILE_DISABLE_PACKED_FP32 0
25 #endif
26 
27 #define WARP_ID 0
28 #define LANE_ID 0
29 
30 #define ENABLE_DEBUG_STMTS 1
31 #if ENABLE_DEBUG_STMTS
32 #define DEBUG_STMTS \
33  if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID)
34 #else
35 #define DEBUG_STMTS if constexpr(false)
36 #endif
37 
38 namespace ck_tile {
39 
40 template <typename PipelineProblem, bool kIsMasking>
42 
// Per-phase instruction-scheduling hints for the FMHA forward v3 core loop,
// specialization for masked attention (/*kIsMasking=*/true).
//
// NOTE(review): this listing was scraped from rendered documentation; the
// original file's line numbers are embedded at the start of each line and the
// schedule() member's signature lines (original lines 47-48) are absent from
// the capture, so this block is not compilable exactly as shown.
//
// Each __builtin_amdgcn_sched_group_barrier(mask, size, 0) call asks the
// backend scheduler to place `size` instructions of the class selected by
// `mask` at this point (per the existing comments: 0x008 = MFMA,
// 0x200 = TRANS, 0x002 = VALU, 0x004 = SALU), interleaving instruction
// classes within each pipeline phase.
43 template <typename PipelineProblem>
44 struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
45 {
46  template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
49  {
50  using namespace ck_tile;
51 
 // Wave group 0: MFMA-heavy interleaving in phases 0 and 2,
 // light VALU/SALU batches in phases 1 and 3.
52  if constexpr(WaveGroup == 0)
53  {
54  if constexpr(Phase == 0)
55  {
 // 8 repetitions of (1 MFMA, 2 TRANS, 2 VALU).
56  static_for<0, 8, 1>{}([&](auto) {
57  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
58  __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
59  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
60  });
61  }
62  else if constexpr(Phase == 1)
63  {
64  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
65  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
66  }
67  else if constexpr(Phase == 2)
68  {
 // Extra VALU batch only when packed fp32 ops are enabled.
69 #if !CK_TILE_DISABLE_PACKED_FP32
70  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
71 #endif
72  static_for<0, 8, 1>{}([&](auto) {
73  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
74  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
75  });
76  }
77  else if constexpr(Phase == 3)
78  {
79  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
80  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
81  }
82  }
 // Wave group 1: the same four schedules offset by one phase, so the
 // two wave groups' MFMA-heavy phases interleave with each other.
83  else
84  {
85  if constexpr(Phase == 0)
86  {
87  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
88  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
89  }
90  else if constexpr(Phase == 1)
91  {
92  static_for<0, 8, 1>{}([&](auto) {
93  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
94  __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
95  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
96  });
97  }
98  else if constexpr(Phase == 2)
99  {
100  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
101  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
102  }
103  else if constexpr(Phase == 3)
104  {
105 #if !CK_TILE_DISABLE_PACKED_FP32
106  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
107 #endif
108  static_for<0, 8, 1>{}([&](auto) {
109  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
110  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
111  });
112  }
113  }
114  }
115 };
116 
// Per-phase instruction-scheduling hints for the FMHA forward v3 core loop,
// specialization for non-masked attention (/*kIsMasking=*/false).
//
// NOTE(review): the phase schedules below are currently byte-identical to
// the kIsMasking=true specialization; the split presumably anticipates
// divergent schedules for the masked path — confirm before unifying.
//
// NOTE(review): scraped listing — the original file's line numbers are
// embedded in each line and the schedule() signature lines (original lines
// 121-122) are missing from the capture.
//
// sched_group_barrier masks (per existing comments): 0x008 = MFMA,
// 0x200 = TRANS, 0x002 = VALU, 0x004 = SALU.
117 template <typename PipelineProblem>
118 struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
119 {
120  template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
123  {
124  using namespace ck_tile;
125 
 // Wave group 0: MFMA-heavy phases 0 and 2, VALU/SALU phases 1 and 3.
126  if constexpr(WaveGroup == 0)
127  {
128  if constexpr(Phase == 0)
129  {
130  static_for<0, 8, 1>{}([&](auto) {
131  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
132  __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
133  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
134  });
135  }
136  else if constexpr(Phase == 1)
137  {
138  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
139  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
140  }
141  else if constexpr(Phase == 2)
142  {
 // Extra VALU batch only when packed fp32 ops are enabled.
143 #if !CK_TILE_DISABLE_PACKED_FP32
144  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
145 #endif
146  static_for<0, 8, 1>{}([&](auto) {
147  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
148  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
149  });
150  }
151  else if constexpr(Phase == 3)
152  {
153  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
154  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
155  }
156  }
 // Wave group 1: same schedules offset by one phase so the groups'
 // MFMA-heavy phases interleave.
157  else
158  {
159  if constexpr(Phase == 0)
160  {
161  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
162  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
163  }
164  else if constexpr(Phase == 1)
165  {
166  static_for<0, 8, 1>{}([&](auto) {
167  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
168  __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
169  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
170  });
171  }
172  else if constexpr(Phase == 2)
173  {
174  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
175  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
176  }
177  else if constexpr(Phase == 3)
178  {
179 #if !CK_TILE_DISABLE_PACKED_FP32
180  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
181 #endif
182  static_for<0, 8, 1>{}([&](auto) {
183  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
184  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
185  });
186  }
187  }
188  }
189 };
190 
191 namespace detail {
192 CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
193 {
194 #if CK_TILE_DISABLE_PACKED_FP32
195  return a * b + c;
196 #else
197  float result;
198  asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]"
199  : [result] "=v"(result)
200  : [a] "v"(a), [b] "s"(b), [c] "v"(c));
201  return result;
202 #endif
203 }
204 
205 CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
206 {
207  float result;
208  asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]"
209  : [result] "=v"(result)
210  : [lhs] "v"(lhs), [rhs] "v"(rhs));
211  return result;
212 }
213 
214 CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
215 {
216  float result;
217  asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]"
218  : [result] "=v"(result)
219  : [lhs] "v"(lhs), [rhs] "v"(rhs));
220  return result;
221 }
222 
224 {
225  fp16x2_t result;
226  asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]"
227  : [result] "=v"(result)
228  : [a] "v"(a), [b] "v"(b));
229  return result;
230 }
231 
233 {
234  bf16x2_t result;
235  asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]"
236  : [result] "=v"(result)
237  : [a] "v"(a), [b] "v"(b));
238  return result;
239 }
240 
242 {
243  fp32x2_t result;
244  asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]"
245  : [result] "=v"(result)
246  : [lhs] "v"(lhs), [rhs] "v"(rhs));
247  return result;
248 }
249 } // namespace detail
250 
253 template <typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
255 {
269  static_assert(is_generic_attention_mask_v<FmhaMask>);
270 
271  static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
272  "we will the same dist tensor 'sp_compute' for both gemm0 & softmax");
273 
275 
277  static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
278 
279  static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize;
280 
281  static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
282  static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
283  static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
284  static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
285  static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
286  static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
287  static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
288 
289  static_assert(kQKHeaddim == 128 && kSubQKHeaddim == 128, "only supports hdim=hdim_v=128");
290 
291  static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
292  static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
293  static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
294  static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
295  static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
296  static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
297  static constexpr auto BiasEnum = Problem::BiasEnum;
298  static constexpr bool kStoreLSE = Problem::kStoreLSE;
299  static constexpr bool kHasDropout = Problem::kHasDropout;
300  static constexpr auto QScaleEnum = Problem::QScaleEnum;
301  static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ;
302  static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kStoreLSE && !kHasDropout &&
304  !kSkipMinSeqlenQ),
305  "enable unsupported features");
306 
307  // last dimension vector length used to create tensor view(and decide buffer_load vector length)
308  // ... together with tensor distribution. tensor dist should able to overwrite this
309  static constexpr ck_tile::index_t kAlignmentQ =
310  kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
311  static constexpr ck_tile::index_t kAlignmentK =
312  kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
313  static constexpr ck_tile::index_t kAlignmentV =
314  kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
315 
316  static constexpr ck_tile::index_t kAlignmentO =
317  kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
318 
319  static constexpr ck_tile::index_t kBlockPerCu = []() {
320  if constexpr(Problem::kBlockPerCu != -1)
321  return Problem::kBlockPerCu;
322  else
323  {
324  return 2;
325  }
326  }();
327 
329  {
330  // create another LDS buffer for p
331  return ck_tile::max(kM0 * kN1 * sizeof(PDataType),
332  Policy::template GetSmemSize<Problem>() +
333  kM0 * kN0 * sizeof(PDataType));
334  }
335 
336  // for debug only
337  template <ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock>
338  CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc()
339  {
340  using namespace ck_tile;
341  constexpr auto lds_block_desc =
344  number<1>{},
345  number<1>{});
346 
347  return lds_block_desc;
348  }
349 
350  // for debug only
351  template <ck_tile::index_t MPerBlock>
352  CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc1D()
353  {
354  using namespace ck_tile;
355  constexpr auto lds_block_desc = make_naive_tensor_descriptor(
357 
358  return lds_block_desc;
359  }
360 
361  template <typename DataType, typename Descriptor>
362  CK_TILE_DEVICE static constexpr auto make_lds_tile_window(void* base, const Descriptor& desc)
363  {
364  using namespace ck_tile;
365 
366  auto tensor_view =
367  make_tensor_view<address_space_enum::lds>(reinterpret_cast<DataType*>(base), desc);
368  return make_tile_window(tensor_view, desc.get_lengths(), {0, 0});
369  }
370 
371  // vmcnt=0~63, lgkmcnt=0~15, expcnt=0~7
372  template <uint16_t Vmcnt, uint8_t Lgkmcnt, uint8_t Expcnt = 7>
373  CK_TILE_DEVICE static constexpr void s_waitcnt()
374  {
375  // vmcnt use bits {[15:14],[3:0]}
376  // expcnt use bits [6:4]
377  // lgkmcnt use bits [11:8]
378  __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) |
379  ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8));
380  }
381 
382  template <uint16_t Vmcnt>
383  CK_TILE_DEVICE static constexpr void s_waitcnt_vmcnt()
384  {
385  s_waitcnt<Vmcnt, 15>();
386  }
387 
388  template <uint8_t Lgkmcnt>
389  CK_TILE_DEVICE static constexpr void s_waitcnt_lgkmcnt()
390  {
391  s_waitcnt<63, Lgkmcnt>();
392  }
393 
394  template <typename QDramBlockWindowTmp,
395  typename KDramBlockWindowTmp,
396  typename VDramBlockWindowTmp,
397  typename LSEDramBlockWindowTmp,
398  typename QElementFunction,
399  typename KElementFunction,
400  typename VElementFunction,
401  typename LSEElementFunction,
402  typename SAccElementFunction,
403  typename PComputeElementFunction,
404  typename OAccElementFunction,
405  typename AttentionVariantParams,
406  typename BlockIndices>
407  CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
408  const QElementFunction& q_element_func,
409  const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
410  [[maybe_unused]] const KElementFunction& k_element_func,
411  const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
412  [[maybe_unused]] const VElementFunction& v_element_func,
413  LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
414  const LSEElementFunction& lse_element_func,
415  [[maybe_unused]] const SAccElementFunction& s_acc_element_func,
416  const PComputeElementFunction& p_compute_element_func,
417  const OAccElementFunction& o_acc_element_func,
418  FmhaMask mask,
419  float scale_s,
420  const AttentionVariant& variant,
421  const AttentionVariantParams& variant_params,
422  const BlockIndices& block_indices,
423  void* smem_ptr) const
424  {
425  using namespace ck_tile;
426 
427  static_assert(
431  "wrong!");
432 
433  static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
434  kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
435  kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
436  kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
437  kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
438  "wrong!");
439 
440  static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
441  auto s_lds = make_tensor_view<address_space_enum::lds>(
442  reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
443  MakeSimpleLdsDesc<kM0, kN0>());
444  [[maybe_unused]] auto s_lds_window =
445  make_tile_window(s_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
446 
447  auto p_lds = make_tensor_view<address_space_enum::lds>(
448  reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr) +
449  Policy::template GetSmemSize<Problem>()),
450  MakeSimpleLdsDesc<kM0, kN0>());
451  [[maybe_unused]] auto p_lds_window =
452  make_tile_window(p_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
453 
454  auto o_lds = make_tensor_view<address_space_enum::lds>(
455  reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
456  MakeSimpleLdsDesc<kM0, kN1>());
457  [[maybe_unused]] auto o_lds_window =
458  make_tile_window(o_lds, make_tuple(number<kM0>{}, number<kN1>{}), {0, 0});
459 
460  auto m_lds = make_tensor_view<address_space_enum::lds>(
461  reinterpret_cast<SMPLComputeDataType*>(static_cast<char*>(smem_ptr) +
462  Policy::template GetSmemSize<Problem>()),
463  MakeSimpleLdsDesc1D<kM0>());
464  [[maybe_unused]] auto m_lds_window =
465  make_tile_window(m_lds, make_tuple(number<kM0>{}), {0});
466 
467  const index_t warp_group_id = get_warp_id() / 4;
468 
469  // Block GEMM
470  constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
471  constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
472 
473  auto q_dram_window = make_tile_window_linear(
474  q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution<Problem>());
475 
476  // reduction function for softmax
477  const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
478  const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
479 
480  auto k_lds_window_store = generate_tuple(
481  [&](auto i_buf) {
482  return make_lds_tile_window<KDataType>(
483  smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
484  },
485  number<2>{});
486 
487  auto v_lds_window_store = generate_tuple(
488  [&](auto i_buf) {
489  return make_lds_tile_window<KDataType>(
490  smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
491  },
492  number<2>{});
493 
495  make_lds_tile_window<KDataType>(
496  nullptr,
497  Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
498  Policy::template MakeKRegTileDistribution<Problem>())),
499  2>
500  k_lds_window_load;
501 
503  make_lds_tile_window<VDataType>(
504  nullptr,
505  Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
506  Policy::template MakeVRegTileDistribution<Problem>())),
507  2>
508  v_lds_window_load;
509 
510  decltype(make_static_distributed_tensor<QDataType>(
511  Policy::template MakeQRegTileDistribution<Problem>())) q_tile;
512 
513  union kv_tile_type
514  {
515  CK_TILE_DEVICE kv_tile_type() {}
516 
517  decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile;
518 
519  decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile;
520  } kv_tile;
521 
522  union sp_compute_type
523  {
524  CK_TILE_DEVICE sp_compute_type() {}
525 
526  decltype(gemm_0.MakeCBlockTile()) sp_compute;
527  decltype(make_static_distributed_tensor<PDataType>(
528  Policy::template MakePRegTileDistribution<Problem>())) p;
529  };
531 
532  decltype(gemm_1.MakeCBlockTile()) o_acc;
533  constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd()
534  // instructions should we move to fmha_alu1()
535  static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
536 
537  decltype(block_tile_reduce<SMPLComputeDataType>(
538  sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m;
539  decltype(m) l;
540 
541  // initialize k_lds_window and v_lds_window
542  static_for<0, 2, 1>{}([&](auto idx) {
543  k_lds_window_load(idx) = make_tile_window(
544  make_lds_tile_window<KDataType>(
545  static_cast<char*>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
546  Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
547  Policy::template MakeKRegTileDistribution<Problem>());
548  });
549 
550  static_for<0, 2, 1>{}([&](auto idx) {
551  v_lds_window_load(idx) =
552  make_tile_window(make_lds_tile_window<VDataType>(
553  static_cast<char*>(smem_ptr) +
554  (idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
555  Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
556  Policy::template MakeVRegTileDistribution<Problem>());
557  });
558 
559  {
560  auto origin_q = load_tile(q_dram_window);
561  auto transformed_q = tile_elementwise_in(q_element_func, origin_q);
562 
563  q_tile = transformed_q;
564  }
565 
566  clear_tile(o_acc);
567  set_tile(m, bit_cast<float>(0xff7fffff)); // a bit larger than -infinity
568  clear_tile(l);
569 
570  const auto q_origin = q_dram_window.get_window_origin();
571  const auto [seqlen_k_start, seqlen_k_end] =
572  mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
573 
574  const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
575  index_t kv_token_start = seqlen_k_start;
576 
577  // check early exit if no work to do
578  if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
579  {
580  if(num_total_loop <= 0)
581  {
582  if constexpr(kStoreLSE)
583  {
584  auto lse =
585  make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
586 
588 
589  store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
590  }
591 
592  // Note: here occ are all cleard, return it
593  // Note: q loaded but no fence, ignore it.
594  return o_acc;
595  }
596  }
597 
598  auto k_dram_window =
599  make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
600  k_dram_block_window_tmp.get_window_lengths(),
601  {seqlen_k_start, 0},
602  Policy::template MakeKDramTileDistribution<Problem>());
603  k_dram_window.init_raw();
604 
605  auto v_dram_window =
606  make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
607  v_dram_block_window_tmp.get_window_lengths(),
608  {seqlen_k_start, 0}, // TODO: hdim split?
609  Policy::template MakeVDramTileDistribution<Problem>());
610  v_dram_window.init_raw();
611 
612  // prefetch K tile
613  index_t i_total_loops = 0;
614  constexpr index_t k0_loops = kQKHeaddim / kK0;
615  constexpr index_t k1_loops = kN0 / kK1;
616  static_assert(1 == k0_loops);
617  static_assert(1 == k1_loops);
618  static_assert(kN0 == kK1);
619 
620  constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
621  static_assert(NumWarpGroups == 2);
622 
623  [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor, const char* name) {
624  printf("[POYENC] %s (size=%d): %5.2f",
625  name,
626  decltype(dist_tensor.thread_buf_)::size(),
627  ck_tile::type_convert<float>(dist_tensor.thread_buf_[0]));
628  static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) {
629  printf(", %5.2f", ck_tile::type_convert<float>(dist_tensor.thread_buf_[i]));
630  });
631  printf("\n");
632  };
633 
634  [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) {
635  const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{});
636  const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{});
637 
638  auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
639  auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
640 
641  if constexpr(true || num_rows < num_cols)
642  {
643  for(int row = 0; row < num_rows; ++row)
644  {
645  int offset = desc.calculate_offset(make_tuple(row, 0));
646  printf("[DEVICE] %s[%3d] = %5.2f",
647  name,
648  row,
649  ck_tile::type_convert<float>(data[offset]));
650  for(int col = 1; col < num_cols; ++col)
651  {
652  printf(", ");
653  offset = desc.calculate_offset(make_tuple(row, col));
654  printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
655  }
656  printf("\n");
657  }
658  }
659  else
660  {
661  for(int col = 0; col < num_cols; ++col)
662  {
663  int offset = desc.calculate_offset(make_tuple(0, col));
664  printf("[DEVICE] %s[%3d] = %5.2f",
665  name,
666  col,
667  ck_tile::type_convert<float>(data[offset]));
668  for(int row = 1; row < num_rows; ++row)
669  {
670  printf(", ");
671  offset = desc.calculate_offset(make_tuple(row, col));
672  printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
673  }
674  printf("\n");
675  }
676  }
677  };
678 
679  [[maybe_unused]] auto print_lds_1d = [&](auto lds_tile_window, const char* name) {
680  const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{});
681 
682  auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
683  auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
684 
685  int offset = desc.calculate_offset(make_tuple(0));
686  printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert<float>(data[offset]));
687  for(int e = 1; e < num_elems; ++e)
688  {
689  printf(", ");
690  offset = desc.calculate_offset(make_tuple(e));
691  printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
692  }
693  printf("\n");
694  };
695 
696  // K_mem_su_ld_insts = 1 for 32 x 128
697  // V_mem_su_ld_insts = 1 for 128 x 32
698  constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
699  constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
700 
701  auto K_mem_load = [&](auto k_lds_write_idx) {
702  async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
703 
705  // move K tile windows
706  move_tile_window(k_dram_window, {kN0, 0});
707  };
708 
709  auto K_lds_load = [&](auto k_lds_read_idx) {
710  kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx));
711  };
712 
713  auto V_mem_load = [&](auto v_lds_write_idx) {
714  async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
715 
717  move_tile_window(v_dram_window, {kK1, 0});
718  };
719 
720  auto V_lds_load = [&](auto v_lds_read_idx) {
721  kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
722  };
723 
724  decltype(m) m_old;
725  SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd()
727  statically_indexed_array<decltype(sp(number<0>{}).sp_compute), 2> sp_delta;
728 
729  auto fmha_logits_trans = [&](auto sp_reg_idx) {
730  if constexpr(kHasLogitsSoftCap)
731  {
732  auto apply_logits_transform = [&variant, &variant_params, &block_indices](
733  auto& logits) {
734  logits = variant.LogitsTransform(variant_params,
735  variant.QueryTransform(variant_params, logits),
736  block_indices.batch_idx,
737  block_indices.qo_head_idx,
738  block_indices.kv_head_idx);
739  };
740 
741  tile_elementwise_inout(apply_logits_transform, sp(sp_reg_idx).sp_compute);
742  }
743  };
744 
745  auto fmha_alu0 = [&](auto sp_reg_idx) {
746  m_old = m; // m{j-1}
747  static_assert(m.thread_buf_.size() == 1,
748  "assuming that each thread holds 1 rowmax value");
749  auto m_latest = block_tile_reduce<SMPLComputeDataType>(
750  sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]);
751 #if defined(__gfx950__)
752  // assuming that we are using 32x32 mfma
753  int32x2_t swapped_regs =
754  __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(m_latest.thread_buf_[0]),
755  bit_cast<int32_t>(m_latest.thread_buf_[0]),
756  false,
757  false);
759  m_latest.thread_buf_[0] = f_max(bit_cast<SMPLComputeDataType>(swapped_regs.x),
760  bit_cast<SMPLComputeDataType>(swapped_regs.y));
761 #else
762  block_tile_reduce_sync(m_latest, f_max, bool_constant<false>{});
763 #endif
764  m = m_latest;
765 
766  constexpr auto p_spans =
767  std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
768  sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
769  sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
770  constexpr auto i_j_idx = make_tuple(idx0, idx1);
771  if constexpr(kHasLogitsSoftCap)
772  {
773  sp_delta(sp_reg_idx)(i_j_idx) =
774  sp(sp_reg_idx).sp_compute(i_j_idx) - m(i_j_idx);
775  }
776  else
777  {
778  sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
779  sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
780  }
781  });
782  });
784  };
785 
786  auto fmha_alu1 = [&](auto sp_reg_idx) {
787  constexpr auto p_spans =
788  std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
789  sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
790  sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
791  constexpr auto i_j_idx = make_tuple(idx0, idx1);
792  sp(sp_reg_idx).sp_compute(i_j_idx) =
793  ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx));
794  });
795  });
796 
797  auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
798  sp(sp_reg_idx).sp_compute,
799  sequence<1>{},
800  f_sum,
801  SMPLComputeDataType{0}); // rowsum(Pcompute{j})
802  static_assert(rowsum_p.thread_buf_.size() == 1,
803  "assuming that each thread holds 1 rowsum value");
804 #if defined(__gfx950__)
805  // assuming that we are using 32x32 mfma
806  int32x2_t swapped_regs =
807  __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
808  bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
809  false,
810  false);
811  rowsum_p.thread_buf_[0] = f_sum(bit_cast<SMPLComputeDataType>(swapped_regs.x),
812  bit_cast<SMPLComputeDataType>(swapped_regs.y));
813 #else
814  block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
815 #endif
816 
817  // l{j}
822  constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
823  sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
824  constexpr auto i_idx = make_tuple(idx0);
825  const auto tmp = [&] {
826  if constexpr(kHasLogitsSoftCap)
827  {
828  return ck_tile::exp2(m_old[i_idx] - m[i_idx]);
829  }
830  else
831  {
832  return ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
833  }
834  }();
835  l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
836  });
837 
838  // update partial o_acc [0, fmha_alu_D_reg_cnt)
839  static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) {
840  o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale);
841  });
842 
847  static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
848  static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
849  float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
850  float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
851  if constexpr(std::is_same_v<PDataType, fp16_t>)
852  {
853  auto casted = detail::cvt_pk_fp16_f32(x, y);
854  sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
855  sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
856  }
857  else
858  {
859  auto casted = detail::cvt_pk_bf16_f32(x, y);
860  sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
861  sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
862  }
863  });
864 
868  };
869 
870  auto gemm = [&](auto sp_reg_idx, auto gemm_idx) {
871  if constexpr(gemm_idx == 0)
872  {
873  clear_tile(sp(sp_reg_idx).sp_compute); // initialize C
874  gemm_0(sp(sp_reg_idx).sp_compute,
875  get_slice_tile(q_tile,
876  sequence<0, (k0_loops - 1) * kK0>{},
878  get_slice_tile(kv_tile.k_tile,
879  sequence<0, (k0_loops - 1) * kK0>{},
881  }
882  else
883  {
884  gemm_1(o_acc,
885  get_slice_tile(sp(sp_reg_idx).p,
886  sequence<0, (k1_loops - 1) * kK1>{},
888  get_slice_tile(kv_tile.v_tile,
889  sequence<0, (k1_loops - 1) * kK1>{},
891  }
892  };
893 
894  auto cl_calc = [&](auto sp_reg_idx, auto gemm_idx) {
895  if constexpr(gemm_idx == 0)
896  {
897  clear_tile(sp(sp_reg_idx).sp_compute); // initialize C
898  gemm_0(sp(sp_reg_idx).sp_compute,
899  get_slice_tile(q_tile,
900  sequence<0, (k0_loops - 1) * kK0>{},
902  get_slice_tile(kv_tile.k_tile,
903  sequence<0, (k0_loops - 1) * kK0>{},
905  }
906  else
907  {
908  gemm_1(o_acc,
909  get_slice_tile(sp(sp_reg_idx).p,
910  sequence<0, (k1_loops - 1) * kK1>{},
912  get_slice_tile(kv_tile.v_tile,
913  sequence<0, (k1_loops - 1) * kK1>{},
915  fmha_alu0(number<1>{} - sp_reg_idx);
916  }
917  };
918 
919  auto fmha_alu_D_upd = [&] {
920  o_acc_scale = [&] {
921  if constexpr(kHasLogitsSoftCap)
922  {
923  return ck_tile::exp2(m_old.thread_buf_[0] - m.thread_buf_[0]);
924  }
925  else
926  {
927  return ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
928  }
929  }();
930 
931  fp32x2_t pk_o_acc_scale;
932  pk_o_acc_scale.x = o_acc_scale;
933  pk_o_acc_scale.y = o_acc_scale;
934 
935  static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0);
936 #if CK_TILE_DISABLE_PACKED_FP32
937  static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size());
939  [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
940 #endif
941 
942  constexpr auto issued_D_reg_cnt =
943 #if CK_TILE_DISABLE_PACKED_FP32
944  fmha_alu_D_reg_cnt + 2
945 #else
946  fmha_alu_D_reg_cnt
947 #endif
948  ;
951  // update partial o_acc after [issued_D_reg_cnt]
952  static_for<issued_D_reg_cnt, o_acc.thread_buf_.size(), 2>{}([&](auto idx) {
953  fp32x2_t input;
954  input.x = o_acc.thread_buf_[idx];
955  input.y = o_acc.thread_buf_[idx + 1];
956 
957  auto output = detail::pk_mul_f32(input, pk_o_acc_scale);
958 
959  o_acc.thread_buf_[idx] = output.x;
960  o_acc.thread_buf_[idx + 1] = output.y;
961  });
962  };
963 
// Lambda: apply the attention mask to the S/P compute tile selected by
// sp_reg_idx. Compiled in only when seq-len-K padding or masking is enabled.
// A per-element pass is done only for "edge" tiles (tiles crossing the mask
// boundary, per mask.IsEdgeTile); interior tiles skip it entirely.
 964  auto fmha_mask = [&](auto sp_reg_idx) {
 965  if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
 966  {
 967  bool need_perpixel_check = mask.IsEdgeTile(
 968  q_origin.at(number<0>{}), kv_token_start, number<kM0>{}, number<kN0>{});
 969  if(need_perpixel_check)
 970  {
// Overwrite elements whose (row, col) is rejected by variant.LogitsMask.
// NOTE(review): the fill-value argument (doc line 972) was elided by the
// documentation extractor — presumably -inf in log2 space; confirm in the
// original header.
 971  set_tile_if(sp(sp_reg_idx).sp_compute,
 973  [&](auto tile_idx) {
 974  const auto row =
 975  q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
 976  const auto col = kv_token_start + tile_idx.at(number<1>{});
// Predicate is true (i.e. element gets overwritten) when the variant's
// logits mask rejects this global (row, col) position.
 977  return !variant.LogitsMask(variant_params,
 978  block_indices.batch_idx,
 979  row,
 980  col,
 981  block_indices.qo_head_idx,
 982  block_indices.kv_head_idx);
 983  });
 984  }
 985  }
 986  };
987 
// Lambda: paired core-loop load step, combining one global-memory prefetch
// with one LDS read.
//   load_type == 0 (memV): fetch next V tile from memory, read K from LDS.
//   load_type != 0 (memK): fetch next K tile from memory, read V from LDS.
// mem_wr_idx / lds_rd_idx select the double-buffer slot for each side.
 988  auto cl_load = [&](auto load_type, auto mem_wr_idx, auto lds_rd_idx) {
 989  if constexpr(load_type == 0)
 990  {
 991  V_mem_load(mem_wr_idx);
 992  K_lds_load(lds_rd_idx);
 993  }
 994  else
 995  {
 996  K_mem_load(mem_wr_idx);
 997  V_lds_load(lds_rd_idx);
 998  }
 999  };
1000 
// Lambda: one pass of the ping-pong software-pipelined main loop.
// cl_p selects the wave-group schedule at compile time (0: waves 0-3,
// 1: waves 4-7); the two groups run complementary 4-phase schedules so one
// group's GEMM overlaps the other's memory traffic. Each call executes two
// iterations (pi = 0 and pi = 1) to cover both double-buffer parities, and
// returns false once i_total_loops reaches num_total_loop.
// WARNING: the exact ordering of __builtin_amdgcn_sched_barrier(0) fences,
// s_barrier waves-sync, s_waitcnt waits, and the cl_calc/cl_load/fmha_* calls
// IS the algorithm here — do not reorder statements.
1001  auto core_loop = [&](auto cl_p) {
// Compile-time selectors for the two GEMMs (S = Q*K, O += P*V) ...
1002  auto gemm0 = number<0>{};
1003  auto gemm1 = number<1>{};
1004 
// ... and for the two cl_load flavors (prefetch V vs. prefetch K).
1005  auto memV = number<0>{};
1006  auto memK = number<1>{};
1007 
1009 
1010  auto iteration = [&](auto pi) {
// Double-buffer parities: S/P register banks for phases 0/1 vs. 2/3 flip
// with pi; LDS read/write slots flip per wave group (w0 = waves 0-3,
// w4 = waves 4-7) according to which buffer the other group is filling.
1011  auto xdl_SP_p01_reg_idx = number<1>{} - pi;
1012  auto xdl_SP_p23_reg_idx = pi;
1013 
1014  auto K_w0_lds_wr_idx = number<1>{} - pi;
1015  auto V_w0_lds_wr_idx = pi;
1016  auto K_w0_lds_rd_idx = pi;
1017  auto V_w0_lds_rd_idx = pi;
1018 
1019  auto K_w4_lds_wr_idx = number<1>{} - pi;
1020  auto V_w4_lds_wr_idx = number<1>{} - pi;
1021  auto K_w4_lds_rd_idx = number<1>{} - pi;
1022  auto V_w4_lds_rd_idx = pi;
1023 
// Becomes false when the loop-trip budget is exhausted (signals the caller
// to leave the while-loop).
1024  bool result = true;
1025 
// ---- Wave group 0 (waves 0-3) schedule ----
1026  if constexpr(cl_p == 0)
1027  {
1028 #if ADD_SBARRIER_FOR_PHASE0
1029  __builtin_amdgcn_sched_barrier(0);
1030  __builtin_amdgcn_s_barrier();
1031 #endif
1032  __builtin_amdgcn_sched_barrier(0);
1033  // phase0
1034  if constexpr(pi == 0)
1035  {
1036  ASM_MARKER("phase0 Wave0-3 (pi=0)");
1037  }
1038  else
1039  {
1040  ASM_MARKER("phase0 Wave0-3 (pi=1)");
1041  }
// Wait for outstanding LDS traffic, then: GEMM0 on this parity's S/P
// registers while finishing softmax ALU + logits transform work.
1042  s_waitcnt_lgkmcnt<0>();
1043  __builtin_amdgcn_sched_barrier(0);
1044  cl_calc(xdl_SP_p01_reg_idx, gemm0);
1045  fmha_alu1(xdl_SP_p23_reg_idx);
1046  fmha_logits_trans(xdl_SP_p01_reg_idx);
1047 
1048  Scheduler::schedule(cl_p, number<0>{});
1049  __builtin_amdgcn_sched_barrier(0);
1050  // phase1
1051  ASM_MARKER("phase1 Wave0-3");
// Gate on global-memory loads in flight (K + V issue counts), sync the
// workgroup, then prefetch next K and consume V from LDS; mask this
// parity's logits.
1052  s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
1053  __builtin_amdgcn_sched_barrier(0);
1054  __builtin_amdgcn_s_barrier();
1055  __builtin_amdgcn_sched_barrier(0);
1056  cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
1057  Scheduler::schedule(cl_p, number<1>{});
1058  fmha_mask(xdl_SP_p01_reg_idx);
1059 
1060  __builtin_amdgcn_sched_barrier(0);
1061  // phase2
1062  ASM_MARKER("phase2 Wave0-3");
// GEMM1 (P*V accumulate) for the other parity, then rescale o_acc.
1063  s_waitcnt_lgkmcnt<0>();
1064  __builtin_amdgcn_sched_barrier(0);
1065  __builtin_amdgcn_s_barrier();
1066  __builtin_amdgcn_sched_barrier(0);
1067  asm volatile("s_nop 0");
1068  __builtin_amdgcn_sched_barrier(0);
1069  cl_calc(xdl_SP_p23_reg_idx, gemm1);
1070 
1071  Scheduler::schedule(cl_p, number<2>{});
1072  __builtin_amdgcn_sched_barrier(0);
1073  fmha_alu_D_upd();
1074 
1075  __builtin_amdgcn_sched_barrier(0);
1076  // phase3
1077  ASM_MARKER("phase3 Wave0-3");
// Prefetch next V / consume K from LDS, advance the KV window, and check
// the trip count.
1078  s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
1079  __builtin_amdgcn_sched_barrier(0);
1080  __builtin_amdgcn_s_barrier();
1081  __builtin_amdgcn_sched_barrier(0);
1082  cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);
1083 
1084  Scheduler::schedule(cl_p, number<3>{});
1085  kv_token_start += kN0;
1086  if(num_total_loop <= ++i_total_loops)
1087  {
1088  result = false;
1089  }
1090  }
// ---- Wave group 1 (waves 4-7) schedule: phases shifted so this group's
// memory phases overlap group 0's compute phases ----
1091  else
1092  {
1093 #if ADD_SBARRIER_FOR_PHASE0
1094  __builtin_amdgcn_sched_barrier(0);
1095  __builtin_amdgcn_s_barrier();
1096 #endif
1097  __builtin_amdgcn_sched_barrier(0);
1098  // phase0
1099  if constexpr(pi == 0)
1100  {
1101  ASM_MARKER("phase0 Wave4-7 (pi=0)");
1102  }
1103  else
1104  {
1105  ASM_MARKER("phase0 Wave4-7 (pi=1)");
1106  }
1107  cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx);
1108 
1109  Scheduler::schedule(cl_p, number<0>{});
1110  __builtin_amdgcn_sched_barrier(0);
1111  // phase1
1112  ASM_MARKER("phase1 Wave4-7");
// Combined vmcnt/lgkmcnt wait, then GEMM0 + softmax ALU + transform.
1113  s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
1114  __builtin_amdgcn_sched_barrier(0);
1115  __builtin_amdgcn_s_barrier();
1116  __builtin_amdgcn_sched_barrier(0);
1117  asm volatile("s_nop 1");
1118  __builtin_amdgcn_sched_barrier(0);
1119  cl_calc(xdl_SP_p01_reg_idx, gemm0);
1120  fmha_alu1(xdl_SP_p23_reg_idx);
1121  fmha_logits_trans(xdl_SP_p01_reg_idx);
1122 
1123  Scheduler::schedule(cl_p, number<1>{});
1124  __builtin_amdgcn_sched_barrier(0);
1125  // phase2
1126  ASM_MARKER("phase2 Wave4-7");
1127  __builtin_amdgcn_s_barrier();
1128  __builtin_amdgcn_sched_barrier(0);
1129  cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
1130  Scheduler::schedule(cl_p, number<2>{});
1131  fmha_mask(xdl_SP_p01_reg_idx);
1132 
// Trip-count bookkeeping happens earlier in this group's schedule than in
// group 0's (before phase3 rather than inside it).
1133  kv_token_start += kN0;
1134  if(num_total_loop <= ++i_total_loops)
1135  {
1136  result = false;
1137  }
1138 
1139  __builtin_amdgcn_sched_barrier(0);
1140  // phase3
1141  ASM_MARKER("phase3 Wave4-7");
1142  s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
1143  __builtin_amdgcn_sched_barrier(0);
1144  __builtin_amdgcn_s_barrier();
1145  __builtin_amdgcn_sched_barrier(0);
1146  asm volatile("s_nop 1");
1147  __builtin_amdgcn_sched_barrier(0);
1148  cl_calc(xdl_SP_p23_reg_idx, gemm1);
1149 
1150  Scheduler::schedule(cl_p, number<3>{});
1151  __builtin_amdgcn_sched_barrier(0);
1152  fmha_alu_D_upd();
1153  }
1154  return result;
1155  };
// Run both buffer parities; short-circuit as soon as one reports the trip
// budget is exhausted.
1156  return iteration(number<0>{}) && iteration(number<1>{});
1157  };
1158 
// Lambda: pipeline drain after the main loop. Waits for the final V tile
// (the vmcnt threshold accounts for a still-in-flight K prefetch when more
// than one loop ran), syncs the workgroup, reads V from LDS, finishes the
// softmax ALU work for the last parity, then runs the final P*V GEMM.
// d selects which double-buffer parity to drain (caller passes 1 for an odd
// trip count, 0 for even); ps_pi = 1 - d is the buffer index actually used.
1159  auto fmha_post_process = [&](auto d) {
1160  auto ps_pi = number<1>{} - d;
1161  auto V_lds_rd_idx = ps_pi;
1162 
1163  if(1 < num_total_loop)
1164  {
1165  s_waitcnt_vmcnt<K_mem_su_ld_insts>();
1166  }
1167  else
1168  {
1169  s_waitcnt_vmcnt<0>();
1170  }
1171  __builtin_amdgcn_s_barrier();
1172 
1173  V_lds_load(V_lds_rd_idx);
1174  fmha_alu1(ps_pi);
1175 
// All LDS reads must land before the accumulate GEMM consumes them.
1176  s_waitcnt_lgkmcnt<0>();
1177 
1178  auto xdl_SP_p23_reg_idx = ps_pi;
1179  gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{});
1180  };
1181 
// ---- Pre-stage: prologue executed by all waves before the ping-pong loop.
// Loads K0, runs the first Q*K0 GEMM + softmax, and prefetches K1/V0 so the
// core loop starts with its double buffers primed.
1182  // pre-stage
1183  {
1184  ASM_MARKER("before pre-stage");
1185  // (1) load K0 to LDS & VGPR
1186  K_mem_load(number<0>{}); // mem_K0
1187 
1188  s_waitcnt_vmcnt<0>();
1189  __builtin_amdgcn_s_barrier();
1190 
1191  K_lds_load(number<0>{}); // lds_K0
1192 
1193  s_waitcnt_lgkmcnt<0>();
1194  __builtin_amdgcn_s_barrier();
1195 
1196  // (2) prefetch K1 and V0 to LDS in parallel with GEMM0
1197  if(1 < num_total_loop)
1198  {
1199  K_mem_load(number<1>{}); // mem_K1
1200  }
1201  V_mem_load(number<0>{}); // mem_V0
1202 
1203  // (3) mfma (Q*K0) + softmax
// NOTE(review): one call between fmha_mask and fmha_alu0 (doc line 1207)
// was elided by the documentation extractor.
1204  gemm(number<0>{}, /*gemm_idx=*/number<0>{});
1205  fmha_logits_trans(number<0>{});
1206  fmha_mask(number<0>{});
1208  fmha_alu0(number<0>{});
1209  fmha_alu_D_upd();
1210 
// Advance KV window; a single-tile problem skips the main loop entirely.
1211  kv_token_start += kN0;
1212  ++i_total_loops;
1213  if(num_total_loop <= i_total_loops)
1214  {
1215  goto label_main_loops_exit;
1216  }
1217 
// With at least three tiles, issue the K2 prefetch now (buffer 0 is free
// again) and wait/sync so the core loop's vmcnt accounting starts correct.
1218  if(2 < num_total_loop)
1219  {
1220  K_mem_load(number<0>{}); // mem_K2
1221 
1222  s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
1223  __builtin_amdgcn_s_barrier();
1224  }
1225 
1226  ASM_MARKER("end pre-stage");
1227  }
1228 
// ---- Main loop: split the workgroup into two wave groups running the two
// complementary core_loop schedules. Group 0 also issues the V1/K1 loads it
// is responsible for; s_setprio biases instruction arbitration between the
// groups (0 vs 1).
1229  if(1 < num_total_loop)
1230  {
1231  if(warp_group_id == 0)
1232  {
1233  V_mem_load(number<1>{}); // V1
1234  K_lds_load(number<1>{}); // K1
1235 
1236  __builtin_amdgcn_s_setprio(0);
1237  __builtin_amdgcn_s_barrier();
1238  while(core_loop(number<0>{}))
1239  ;
1240  }
1241  if(warp_group_id != 0)
1242  {
1243  __builtin_amdgcn_s_setprio(1);
1244  __builtin_amdgcn_s_barrier();
1245  while(core_loop(number<1>{}))
1246  ;
1247  }
1248  }
// ---- Epilogue: drain the pipeline for whichever buffer parity the trip
// count ended on (odd -> 1, even -> 0).
1249  label_main_loops_exit:
1250  if(num_total_loop % 2)
1251  {
1252  fmha_post_process(number<1>{});
1253  }
1254  if(!(num_total_loop % 2))
1255  {
1256  fmha_post_process(number<0>{});
1257  }
1258 
// LSE = m / log2(e) + log(l): convert the log2-domain row max back to
// natural log and add log of the row sum.
1259  // store lse
1260  if constexpr(kStoreLSE)
1261  {
1262  auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
1263 
1264  constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
1265  sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) {
1266  constexpr auto i_idx = make_tuple(idx0);
1267  lse(i_idx) = m[i_idx] / C_LOG2E + log(l[i_idx]);
1268  });
1269 
1270  store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
1271  }
1272 
// Final softmax normalization: divide each output row by its sum l. Fully
// masked rows (l == 0, possible only when masking) produce 0 instead of inf.
1273  // finally, O
1274  constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
1275 
1276  sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
1277  constexpr auto i_idx = make_tuple(idx0);
1278  const auto tmp = [&]() {
1279  if constexpr(FmhaMask::IsMasking)
1280  {
1281  return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
1282  }
1283  else
1284  return 1 / l[i_idx];
1285  }();
1286  sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
1287  constexpr auto i_j_idx = make_tuple(idx0, idx1);
1288  o_acc(i_j_idx) *= tmp;
1289  });
1290  });
1291 
// Caller-supplied element-wise epilogue (e.g. type conversion) on o_acc.
1292  o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
1293 
1294  return o_acc;
1295  }
1296 
// Convenience overload: forwards to the full operator() (doc line 407),
// passing identity{} for every element-wise functor slot (Q/K/V/LSE/S-acc/
// P-compute/O-acc), i.e. no per-element transforms are applied.
1297  template <typename QDramBlockWindowTmp,
1298  typename KDramBlockWindowTmp,
1299  typename VDramBlockWindowTmp,
1300  typename LSEDramBlockWindowTmp,
1301  typename AttentionVariantParams,
1302  typename BlockIndices>
1303  CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
1304  const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
1305  const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
1306  LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
1307  FmhaMask mask,
1308  float scale_s,
1309  const AttentionVariant& variant,
1310  const AttentionVariantParams& variant_params,
1311  const BlockIndices& block_indices,
1312  void* smem_ptr) const
1313  {
1314  using namespace ck_tile;
1315 
1316  return operator()(q_dram_block_window_tmp,
1317  identity{},
1318  k_dram_block_window_tmp,
1319  identity{},
1320  v_dram_block_window_tmp,
1321  identity{},
1322  lse_dram_block_window_tmp,
1323  identity{},
1324  identity{},
1325  identity{},
1326  identity{},
1327  mask,
1328  scale_s,
1329  variant,
1330  variant_params,
1331  block_indices,
1332  smem_ptr);
1333  }
1334 };
1335 
1336 } // namespace ck_tile
#define ASM_MARKER(marker)
Definition: block_fmha_fwd_v3_pipeline.hpp:14
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
Definition: block_fmha_fwd_v3_pipeline.hpp:232
CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
Definition: block_fmha_fwd_v3_pipeline.hpp:192
CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
Definition: block_fmha_fwd_v3_pipeline.hpp:214
CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
Definition: block_fmha_fwd_v3_pipeline.hpp:205
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
Definition: block_fmha_fwd_v3_pipeline.hpp:223
CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
Definition: block_fmha_fwd_v3_pipeline.hpp:241
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:431
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
constexpr CK_TILE_DEVICE auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition: slice_tile.hpp:23
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition: tile_elementwise.hpp:95
CK_TILE_DEVICE void async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: load_tile.hpp:184
bfloat16_t bf16x2_t
Definition: bfloat16.hpp:433
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition: block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition: tile_elementwise.hpp:23
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
float fp32x2_t
Definition: bfloat16.hpp:434
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
_Float16 fp16x2_t
Definition: half.hpp:385
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition: static_distributed_tensor.hpp:185
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:994
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:24
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:177
int32_t int32x2_t
Definition: vector_type.hpp:154
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &__restrict__ tile_window)
transpose loads tile from a tensor and returns the resulting tensor with a new (transposed) tile dist...
Definition: load_tile_transpose.hpp:486
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:36
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:428
constexpr bool is_same_v
Definition: type.hpp:283
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
Definition: block_fmha_fwd_v3_pipeline.hpp:255
ck_tile::remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition: block_fmha_fwd_v3_pipeline.hpp:267
ck_tile::remove_cvref_t< typename Problem::PDataType > PDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:264
ck_tile::remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:262
ck_tile::remove_cvref_t< typename Problem::KDataType > KDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:259
ck_tile::remove_cvref_t< typename Problem::ODataType > ODataType
Definition: block_fmha_fwd_v3_pipeline.hpp:266
ck_tile::remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition: block_fmha_fwd_v3_pipeline.hpp:274
static constexpr CK_TILE_DEVICE void s_waitcnt_lgkmcnt()
Definition: block_fmha_fwd_v3_pipeline.hpp:389
ck_tile::remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:265
static constexpr CK_TILE_DEVICE auto MakeSimpleLdsDesc1D()
Definition: block_fmha_fwd_v3_pipeline.hpp:352
static constexpr CK_TILE_DEVICE void s_waitcnt()
Definition: block_fmha_fwd_v3_pipeline.hpp:373
ck_tile::remove_cvref_t< Problem_ > Problem
Definition: block_fmha_fwd_v3_pipeline.hpp:256
static constexpr CK_TILE_DEVICE void s_waitcnt_vmcnt()
Definition: block_fmha_fwd_v3_pipeline.hpp:383
static constexpr CK_TILE_DEVICE auto make_lds_tile_window(void *base, const Descriptor &desc)
Definition: block_fmha_fwd_v3_pipeline.hpp:362
ck_tile::remove_cvref_t< typename Problem::QDataType > QDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:258
ck_tile::remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:263
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: block_fmha_fwd_v3_pipeline.hpp:328
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, [[maybe_unused]] const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, [[maybe_unused]] const VElementFunction &v_element_func, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, [[maybe_unused]] const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr) const
Definition: block_fmha_fwd_v3_pipeline.hpp:407
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition: block_fmha_fwd_v3_pipeline.hpp:276
ck_tile::remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:261
ck_tile::remove_cvref_t< Policy_ > Policy
Definition: block_fmha_fwd_v3_pipeline.hpp:257
ck_tile::remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition: block_fmha_fwd_v3_pipeline.hpp:268
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr) const
Definition: block_fmha_fwd_v3_pipeline.hpp:1303
ck_tile::remove_cvref_t< typename Problem::VDataType > VDataType
Definition: block_fmha_fwd_v3_pipeline.hpp:260
static constexpr CK_TILE_DEVICE auto MakeSimpleLdsDesc()
Definition: block_fmha_fwd_v3_pipeline.hpp:338
static constexpr CK_TILE_DEVICE void schedule(ck_tile::number< WaveGroup >, ck_tile::number< Phase >)
Definition: block_fmha_fwd_v3_pipeline.hpp:121
static constexpr CK_TILE_DEVICE void schedule(ck_tile::number< WaveGroup >, ck_tile::number< Phase >)
Definition: block_fmha_fwd_v3_pipeline.hpp:47
Definition: block_fmha_fwd_v3_pipeline.hpp:41
Definition: integral_constant.hpp:13
Definition: functional.hpp:114
Definition: numeric.hpp:18
Definition: coordinate_transform.hpp:1392
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
#define C_LOG2E
Definition: math.hpp:462