Composable Kernel: include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp Source File

block_fmha_fwd_v3_pipeline.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 #define ENABLE_ASM_MARKER 1
11 #if ENABLE_ASM_MARKER
12 #define ASM_MARKER(marker) \
13  __builtin_amdgcn_sched_barrier(0); \
14  asm volatile("; [POYENC] " #marker); \
15  __builtin_amdgcn_sched_barrier(0);
16 #else
17 #define ASM_MARKER(marker)
18 #endif
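// [editor annotation, not part of the original source] ASM_MARKER plants a plain
// comment into the generated ISA so each pipeline phase can be located in a
// disassembly dump; the surrounding __builtin_amdgcn_sched_barrier(0) fences keep the
// compiler from moving instructions across the marker. Sketch of the expansion for a
// hypothetical tag:
//
//   ASM_MARKER(phase0);
//   // => __builtin_amdgcn_sched_barrier(0);
//   //    asm volatile("; [POYENC] phase0");  // visible as a comment in the ISA
//   //    __builtin_amdgcn_sched_barrier(0);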
19 
20 #define ADD_SBARRIER_FOR_PHASE0 1
21 #if !defined(CK_TILE_DISABLE_PACKED_FP32)
22 #define CK_TILE_DISABLE_PACKED_FP32 0
23 #endif
24 
25 #define WARP_ID 0
26 #define LANE_ID 0
27 
28 #define ENABLE_DEBUG_STMTS 1
29 #if ENABLE_DEBUG_STMTS
30 #define DEBUG_STMTS \
31  if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID)
32 #else
33 #define DEBUG_STMTS if constexpr(false)
34 #endif
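// [editor annotation] DEBUG_STMTS guards a statement so it executes on a single thread
// (block 0, warp WARP_ID, lane LANE_ID), keeping device-side printf output readable.
// Hypothetical usage inside the kernel:
//
//   DEBUG_STMTS { printf("[DEVICE] i_total_loops = %d\n", i_total_loops); }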
35 
36 namespace ck_tile {
37 
38 template <typename PipelineProblem, bool kIsMasking>
39 struct CoreLoopScheduler;
40 
41 template <typename PipelineProblem>
42 struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
43 {
44  template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
45  CK_TILE_DEVICE static constexpr void schedule(ck_tile::number<WaveGroup>,
46  ck_tile::number<Phase>)
47  {
48  using namespace ck_tile;
49 
50  if constexpr(WaveGroup == 0)
51  {
52  if constexpr(Phase == 0)
53  {
54  static_for<0, 8, 1>{}([&](auto) {
55  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
56  __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
57  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
58  });
59  }
60  else if constexpr(Phase == 1)
61  {
62  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
63  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
64  }
65  else if constexpr(Phase == 2)
66  {
67 #if !CK_TILE_DISABLE_PACKED_FP32
68  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
69 #endif
70  static_for<0, 8, 1>{}([&](auto) {
71  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
72  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
73  });
74  }
75  else if constexpr(Phase == 3)
76  {
77  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
78  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
79  }
80  }
81  else
82  {
83  if constexpr(Phase == 0)
84  {
85  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
86  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
87  }
88  else if constexpr(Phase == 1)
89  {
90  static_for<0, 8, 1>{}([&](auto) {
91  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
92  __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
93  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
94  });
95  }
96  else if constexpr(Phase == 2)
97  {
98  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
99  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
100  }
101  else if constexpr(Phase == 3)
102  {
103 #if !CK_TILE_DISABLE_PACKED_FP32
104  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
105 #endif
106  static_for<0, 8, 1>{}([&](auto) {
107  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
108  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
109  });
110  }
111  }
112  }
113 };
114 
115 template <typename PipelineProblem>
116 struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
117 {
118  template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
119  CK_TILE_DEVICE static constexpr void schedule(ck_tile::number<WaveGroup>,
120  ck_tile::number<Phase>)
121  {
122  using namespace ck_tile;
123 
124  if constexpr(WaveGroup == 0)
125  {
126  if constexpr(Phase == 0)
127  {
128  static_for<0, 8, 1>{}([&](auto) {
129  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
130  __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
131  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
132  });
133  }
134  else if constexpr(Phase == 1)
135  {
136  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
137  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
138  }
139  else if constexpr(Phase == 2)
140  {
141 #if !CK_TILE_DISABLE_PACKED_FP32
142  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
143 #endif
144  static_for<0, 8, 1>{}([&](auto) {
145  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
146  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
147  });
148  }
149  else if constexpr(Phase == 3)
150  {
151  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
152  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
153  }
154  }
155  else
156  {
157  if constexpr(Phase == 0)
158  {
159  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
160  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
161  }
162  else if constexpr(Phase == 1)
163  {
164  static_for<0, 8, 1>{}([&](auto) {
165  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
166  __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
167  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
168  });
169  }
170  else if constexpr(Phase == 2)
171  {
172  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
173  __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
174  }
175  else if constexpr(Phase == 3)
176  {
177 #if !CK_TILE_DISABLE_PACKED_FP32
178  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
179 #endif
180  static_for<0, 8, 1>{}([&](auto) {
181  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
182  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
183  });
184  }
185  }
186  }
187 };
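// [editor annotation] __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) asks
// the LLVM AMDGPU scheduler to place a group of `size` instructions matching `mask` at
// this point; a run of consecutive calls therefore spells out a desired interleaving.
// Using the mask values as commented in this file (0x008 MFMA, 0x200 TRANS, 0x002 VALU,
// 0x004 SALU), the Phase 0 pattern above requests eight repetitions of
//
//   1 MFMA, 2 TRANS, 2 VALU
//
// so the exp2/transcendental and VALU work of softmax is interleaved with the MFMAs.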
188 
189 namespace detail {
190 CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
191 {
192 #if CK_TILE_DISABLE_PACKED_FP32
193  return a * b + c;
194 #else
195  float result;
196  asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]"
197  : [result] "=v"(result)
198  : [a] "v"(a), [b] "s"(b), [c] "v"(c));
199  return result;
200 #endif
201 }
202 
203 CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
204 {
205  float result;
206  asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]"
207  : [result] "=v"(result)
208  : [lhs] "v"(lhs), [rhs] "v"(rhs));
209  return result;
210 }
211 
212 CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
213 {
214  float result;
215  asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]"
216  : [result] "=v"(result)
217  : [lhs] "v"(lhs), [rhs] "v"(rhs));
218  return result;
219 }
220 
221 CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
222 {
223  fp16x2_t result;
224  asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]"
225  : [result] "=v"(result)
226  : [a] "v"(a), [b] "v"(b));
227  return result;
228 }
229 
230 CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
231 {
232  bf16x2_t result;
233  asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]"
234  : [result] "=v"(result)
235  : [a] "v"(a), [b] "v"(b));
236  return result;
237 }
238 
239 CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
240 {
241  fp32x2_t result;
242  asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]"
243  : [result] "=v"(result)
244  : [lhs] "v"(lhs), [rhs] "v"(rhs));
245  return result;
246 }
247 } // namespace detail
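// [editor annotation] These wrappers pin the exact instruction and operand classes
// (v = VGPR, s = SGPR) through inline asm, so the scheduling phases above can count on
// a fixed number of VALU instructions instead of whatever the compiler would select.
// For example, fma_impl_vsv keeps the wave-uniform scale in an SGPR, mirroring its use
// in fmha_alu0 below (identifiers here are illustrative):
//
//   // delta = scale_s * s - scale_s * row_max, with scale_s uniform across the wave
//   float delta = detail::fma_impl_vsv(s, scale_s, -scale_s * row_max);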
248 
249 template <typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
250 struct BlockFmhaFwdV3Pipeline
251 {
252  using Problem = ck_tile::remove_cvref_t<Problem_>;
253  using Policy = ck_tile::remove_cvref_t<Policy_>;
254  using QDataType = ck_tile::remove_cvref_t<typename Problem::QDataType>;
255  using KDataType = ck_tile::remove_cvref_t<typename Problem::KDataType>;
256  using VDataType = ck_tile::remove_cvref_t<typename Problem::VDataType>;
257  using SaccDataType = ck_tile::remove_cvref_t<typename Problem::SaccDataType>;
258  using SMPLComputeDataType = ck_tile::remove_cvref_t<typename Problem::SMPLComputeDataType>;
259  using LSEDataType = ck_tile::remove_cvref_t<typename Problem::LSEDataType>;
260  using PDataType = ck_tile::remove_cvref_t<typename Problem::PDataType>;
261  using OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
262  using ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType>;
263  using FmhaMask = ck_tile::remove_cvref_t<typename Problem::FmhaMask>;
264 
265  static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
266  "we will use the same dist tensor 'sp_compute' for both gemm0 & softmax");
267 
268  using BlockFmhaShape = ck_tile::remove_cvref_t<typename Problem::BlockFmhaShape>;
269 
270  static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize;
271 
272  static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
273  static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
274  static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
275  static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
276  static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
277  static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
278  static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
279 
280  static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
281 
282  static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
283  static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
284  static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
285  static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
286  static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
287  static constexpr bool kStoreLSE = Problem::kStoreLSE;
288 
289  // last dimension vector length used to create the tensor view (and decide the buffer_load
290  // vector length) together with the tensor distribution. tensor dist should be able to override this
291  static constexpr ck_tile::index_t kAlignmentQ =
292  kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
293  static constexpr ck_tile::index_t kAlignmentK =
294  kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
295  static constexpr ck_tile::index_t kAlignmentV =
296  kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
297 
298  static constexpr ck_tile::index_t kAlignmentO =
299  kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
300 
301  static constexpr ck_tile::index_t kBlockPerCu = []() {
302  if constexpr(Problem::kBlockPerCu != -1)
303  return Problem::kBlockPerCu;
304  else
305  {
306  return 2;
307  }
308  }();
309 
310  static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
311  {
312  // create another LDS buffer for p
313  return ck_tile::max(kM0 * kN1 * sizeof(PDataType),
314  Policy::template GetSmemSize<Problem>() +
315  kM0 * kN0 * sizeof(PDataType));
316  }
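// [editor annotation] The LDS budget is the larger of the O/P overlay tile and the
// policy's K/V buffers plus the extra P scratch. A worked example with purely
// hypothetical sizes (kM0 = 128, kN0 = kN1 = 32, sizeof(PDataType) = 2, policy
// smem = 64 KiB):
//
//   max(128 * 32 * 2, 65536 + 128 * 32 * 2) = max(8 KiB, 72 KiB) = 72 KiB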
317 
318  // for debug only
319  template <ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock>
320  CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc()
321  {
322  using namespace ck_tile;
323  constexpr auto lds_block_desc =
324  make_naive_tensor_descriptor(make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
325  make_tuple(number<NPerBlock>{}, number<1>{}),
326  number<1>{},
327  number<1>{});
328 
329  return lds_block_desc;
330  }
331 
332  // for debug only
333  template <ck_tile::index_t MPerBlock>
334  CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc1D()
335  {
336  using namespace ck_tile;
337  constexpr auto lds_block_desc = make_naive_tensor_descriptor(
338  make_tuple(number<MPerBlock>{}), make_tuple(number<1>{}));
339 
340  return lds_block_desc;
341  }
342 
343  template <typename DataType, typename Descriptor>
344  CK_TILE_DEVICE static constexpr auto make_lds_tile_window(void* base, const Descriptor& desc)
345  {
346  using namespace ck_tile;
347 
348  auto tensor_view =
349  make_tensor_view<address_space_enum::lds>(reinterpret_cast<DataType*>(base), desc);
350  return make_tile_window(tensor_view, desc.get_lengths(), {0, 0});
351  }
352 
353  // vmcnt=0~63, lgkmcnt=0~15, expcnt=0~7
354  template <uint16_t Vmcnt, uint8_t Lgkmcnt, uint8_t Expcnt = 7>
355  CK_TILE_DEVICE static constexpr void s_waitcnt()
356  {
357  // vmcnt use bits {[15:14],[3:0]}
358  // expcnt use bits [6:4]
359  // lgkmcnt use bits [11:8]
360  __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) |
361  ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8));
362  }
363 
364  template <uint16_t Vmcnt>
365  CK_TILE_DEVICE static constexpr void s_waitcnt_vmcnt()
366  {
367  s_waitcnt<Vmcnt, 15>();
368  }
369 
370  template <uint8_t Lgkmcnt>
371  CK_TILE_DEVICE static constexpr void s_waitcnt_lgkmcnt()
372  {
373  s_waitcnt<63, Lgkmcnt>();
374  }
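// [editor annotation] Worked encodings for the s_waitcnt immediate built above, using
// the bit layout from the comments (vmcnt in {[15:14],[3:0]}, expcnt in [6:4], lgkmcnt
// in [11:8]):
//
//   s_waitcnt<0, 0>()   // imm = 0x0070: wait for vmcnt==0 and lgkmcnt==0 (expcnt=7 ignored)
//   s_waitcnt<63, 15>() // imm = 0xCF7F: every counter at its max, i.e. wait on nothing
//
// s_waitcnt_vmcnt / s_waitcnt_lgkmcnt simply pass the maximum for the counter that
// should not be waited on.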
375 
376  template <typename QDramBlockWindowTmp,
377  typename KDramBlockWindowTmp,
378  typename VDramBlockWindowTmp,
379  typename LSEDramBlockWindowTmp,
380  typename QElementFunction,
381  typename KElementFunction,
382  typename VElementFunction,
383  typename LSEElementFunction,
384  typename SAccElementFunction,
385  typename PComputeElementFunction,
386  typename OAccElementFunction>
387  CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
388  const QElementFunction& q_element_func,
389  const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
390  [[maybe_unused]] const KElementFunction& k_element_func,
391  const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
392  [[maybe_unused]] const VElementFunction& v_element_func,
393  LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
394  const LSEElementFunction& lse_element_func,
395  [[maybe_unused]] const SAccElementFunction& s_acc_element_func,
396  const PComputeElementFunction& p_compute_element_func,
397  const OAccElementFunction& o_acc_element_func,
398  FmhaMask mask,
399  float scale_s,
400  void* smem_ptr) const
401  {
402  using namespace ck_tile;
403 
404  static_assert(
405  std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
406  std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
407  std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
408  "wrong!");
409 
410  static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
411  kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
412  kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
413  kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
414  kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
415  "wrong!");
416 
417  static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
418  auto s_lds = make_tensor_view<address_space_enum::lds>(
419  reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
420  MakeSimpleLdsDesc<kM0, kN0>());
421  [[maybe_unused]] auto s_lds_window =
422  make_tile_window(s_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
423 
424  auto p_lds = make_tensor_view<address_space_enum::lds>(
425  reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr) +
426  Policy::template GetSmemSize<Problem>()),
427  MakeSimpleLdsDesc<kM0, kN0>());
428  [[maybe_unused]] auto p_lds_window =
429  make_tile_window(p_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
430 
431  auto o_lds = make_tensor_view<address_space_enum::lds>(
432  reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
433  MakeSimpleLdsDesc<kM0, kN1>());
434  [[maybe_unused]] auto o_lds_window =
435  make_tile_window(o_lds, make_tuple(number<kM0>{}, number<kN1>{}), {0, 0});
436 
437  auto m_lds = make_tensor_view<address_space_enum::lds>(
438  reinterpret_cast<SMPLComputeDataType*>(static_cast<char*>(smem_ptr) +
439  Policy::template GetSmemSize<Problem>()),
440  MakeSimpleLdsDesc1D<kM0>());
441  [[maybe_unused]] auto m_lds_window =
442  make_tile_window(m_lds, make_tuple(number<kM0>{}), {0});
443 
444  const index_t warp_group_id = get_warp_id() / 4;
445 
446  // Block GEMM
447  constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
448  constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
449 
450  auto q_dram_window = make_tile_window_linear(
451  q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution<Problem>());
452 
453  // reduction function for softmax
454  const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
455  const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
456 
457  auto k_lds_window_store = generate_tuple(
458  [&](auto i_buf) {
459  return make_lds_tile_window<KDataType>(
460  smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
461  },
462  number<2>{});
463 
464  auto v_lds_window_store = generate_tuple(
465  [&](auto i_buf) {
466  return make_lds_tile_window<VDataType>(
467  smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
468  },
469  number<2>{});
470 
471  statically_indexed_array<decltype(make_tile_window(
472  make_lds_tile_window<KDataType>(
473  nullptr,
474  Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
475  Policy::template MakeKRegTileDistribution<Problem>())),
476  2>
477  k_lds_window_load;
478 
479  statically_indexed_array<decltype(make_tile_window(
480  make_lds_tile_window<VDataType>(
481  nullptr,
482  Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
483  Policy::template MakeVRegTileDistribution<Problem>())),
484  2>
485  v_lds_window_load;
486 
487  decltype(make_static_distributed_tensor<QDataType>(
488  Policy::template MakeQRegTileDistribution<Problem>())) q_tile;
489 
490  union kv_tile_type
491  {
492  CK_TILE_DEVICE kv_tile_type() {}
493 
494  decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile;
495 
496  decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile;
497  } kv_tile;
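// [editor annotation] The K and V register tiles are never live at the same time (a
// phase consumes either K for gemm0 or V for gemm1), so overlaying them in a union
// halves the VGPR footprint of the operand tiles; sp_compute_type below plays the same
// trick to overlay the fp32 accumulator with its packed fp16/bf16 copy 'p'.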
498 
499  union sp_compute_type
500  {
501  CK_TILE_DEVICE sp_compute_type() {}
502 
503  decltype(gemm_0.MakeCBlockTile()) sp_compute;
504  decltype(make_static_distributed_tensor<PDataType>(
505  Policy::template MakePRegTileDistribution<Problem>())) p;
506  };
507  statically_indexed_array<sp_compute_type, 2> sp;
508 
509  decltype(gemm_1.MakeCBlockTile()) o_acc;
510  constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd()
511  // instructions we should move to fmha_alu1()
512  static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
513 
514  decltype(block_tile_reduce<SMPLComputeDataType>(
515  sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m;
516  decltype(m) l;
517 
518  // initialize k_lds_window and v_lds_window
519  static_for<0, 2, 1>{}([&](auto idx) {
520  k_lds_window_load(idx) = make_tile_window(
521  make_lds_tile_window<KDataType>(
522  static_cast<char*>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
523  Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
524  Policy::template MakeKRegTileDistribution<Problem>());
525  });
526 
527  static_for<0, 2, 1>{}([&](auto idx) {
528  v_lds_window_load(idx) =
529  make_tile_window(make_lds_tile_window<VDataType>(
530  static_cast<char*>(smem_ptr) +
531  (idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
532  Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
533  Policy::template MakeVRegTileDistribution<Problem>());
534  });
535 
536  {
537  auto origin_q = load_tile(q_dram_window);
538  auto transformed_q = tile_elementwise_in(q_element_func, origin_q);
539 
540  q_tile = transformed_q;
541  }
542 
543  clear_tile(o_acc);
544  set_tile(m, bit_cast<float>(0xff7fffff)); // a bit larger than -infinity
545  clear_tile(l);
546 
547  const auto q_origin = q_dram_window.get_window_origin();
548  const auto [seqlen_k_start, seqlen_k_end] =
549  mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
550 
551  const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
552  index_t kv_token_start = seqlen_k_start;
553 
554  // check early exit if no work to do
555  if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
556  {
557  if(num_total_loop <= 0)
558  {
559  if constexpr(kStoreLSE)
560  {
561  auto lse =
562  make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
563 
564  set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
565 
566  store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
567  }
568 
569  // Note: here o_acc is all cleared, return it
570  // Note: q loaded but no fence, ignore it.
571  return o_acc;
572  }
573  }
574 
575  auto k_dram_window =
576  make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
577  k_dram_block_window_tmp.get_window_lengths(),
578  {seqlen_k_start, 0},
579  Policy::template MakeKDramTileDistribution<Problem>());
580  k_dram_window.init_raw();
581 
582  auto v_dram_window =
583  make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
584  v_dram_block_window_tmp.get_window_lengths(),
585  {seqlen_k_start, 0}, // TODO: hdim split?
586  Policy::template MakeVDramTileDistribution<Problem>());
587  v_dram_window.init_raw();
588 
589  // prefetch K tile
590  index_t i_total_loops = 0;
591  constexpr index_t k0_loops = kQKHeaddim / kK0;
592  constexpr index_t k1_loops = kN0 / kK1;
593  static_assert(1 == k0_loops);
594  static_assert(1 == k1_loops);
595  static_assert(kN0 == kK1);
596 
597  constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
598  static_assert(NumWarpGroups == 2);
599 
600  [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor, const char* name) {
601  printf("[POYENC] %s (size=%d): %5.2f",
602  name,
603  decltype(dist_tensor.thread_buf_)::size(),
604  ck_tile::type_convert<float>(dist_tensor.thread_buf_[0]));
605  static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) {
606  printf(", %5.2f", ck_tile::type_convert<float>(dist_tensor.thread_buf_[i]));
607  });
608  printf("\n");
609  };
610 
611  [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) {
612  const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{});
613  const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{});
614 
615  auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
616  auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
617 
618  if constexpr(true || num_rows < num_cols)
619  {
620  for(int row = 0; row < num_rows; ++row)
621  {
622  int offset = desc.calculate_offset(make_tuple(row, 0));
623  printf("[DEVICE] %s[%3d] = %5.2f",
624  name,
625  row,
626  ck_tile::type_convert<float>(data[offset]));
627  for(int col = 1; col < num_cols; ++col)
628  {
629  printf(", ");
630  offset = desc.calculate_offset(make_tuple(row, col));
631  printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
632  }
633  printf("\n");
634  }
635  }
636  else
637  {
638  for(int col = 0; col < num_cols; ++col)
639  {
640  int offset = desc.calculate_offset(make_tuple(0, col));
641  printf("[DEVICE] %s[%3d] = %5.2f",
642  name,
643  col,
644  ck_tile::type_convert<float>(data[offset]));
645  for(int row = 1; row < num_rows; ++row)
646  {
647  printf(", ");
648  offset = desc.calculate_offset(make_tuple(row, col));
649  printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
650  }
651  printf("\n");
652  }
653  }
654  };
655 
656  [[maybe_unused]] auto print_lds_1d = [&](auto lds_tile_window, const char* name) {
657  const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{});
658 
659  auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
660  auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
661 
662  int offset = desc.calculate_offset(make_tuple(0));
663  printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert<float>(data[offset]));
664  for(int e = 1; e < num_elems; ++e)
665  {
666  printf(", ");
667  offset = desc.calculate_offset(make_tuple(e));
668  printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
669  }
670  printf("\n");
671  };
672 
673  // K_mem_su_ld_insts = 1 for 32 x 128
674  // V_mem_su_ld_insts = 1 for 128 x 32
675  constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
676  constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
677 
678  auto K_mem_load = [&](auto k_lds_write_idx) {
679  async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
680 
682  // move K tile windows
683  move_tile_window(k_dram_window, {kN0, 0});
684  };
685 
686  auto K_lds_load = [&](auto k_lds_read_idx) {
687  kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx));
688  };
689 
690  auto V_mem_load = [&](auto v_lds_write_idx) {
691  async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
692 
694  move_tile_window(v_dram_window, {kK1, 0});
695  };
696 
697  auto V_lds_load = [&](auto v_lds_read_idx) {
698  kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
699  };
700 
701  decltype(m) m_old;
702  SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd()
704  statically_indexed_array<decltype(sp(number<0>{}).sp_compute), 2> sp_delta;
705 
706  auto fmha_alu0 = [&](auto sp_reg_idx) {
707  m_old = m; // m{j-1}
708  static_assert(m.thread_buf_.size() == 1,
709  "assuming that each thread holds 1 rowmax value");
710  auto m_latest = block_tile_reduce<SMPLComputeDataType>(
711  sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]);
712 #if defined(__gfx950__)
713  // assuming that we are using 32x32 mfma
714  int32x2_t swapped_regs =
715  __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(m_latest.thread_buf_[0]),
716  bit_cast<int32_t>(m_latest.thread_buf_[0]),
717  false,
718  false);
720  m_latest.thread_buf_[0] = f_max(bit_cast<SMPLComputeDataType>(swapped_regs.x),
721  bit_cast<SMPLComputeDataType>(swapped_regs.y));
722 #else
723  block_tile_reduce_sync(m_latest, f_max, bool_constant<false>{});
724 #endif
725  m = m_latest;
726 
727  constexpr auto p_spans =
728  std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
729  sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
730  sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
731  constexpr auto i_j_idx = make_tuple(idx0, idx1);
732  sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
733  sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
734  });
735  });
737  };
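// [editor annotation] On gfx950 the cross-lane reduction in fmha_alu0 is a single
// permlane32_swap: it exchanges values between lanes 0-31 and lanes 32-63 of the
// wave64, whose halves hold the two partial row maxima under the 32x32 MFMA layout, so
// one swap plus one max replaces the generic block_tile_reduce_sync path.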
738 
739  auto fmha_alu1 = [&](auto sp_reg_idx) {
740  constexpr auto p_spans =
741  std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
742  sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
743  sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
744  constexpr auto i_j_idx = make_tuple(idx0, idx1);
745  sp(sp_reg_idx).sp_compute(i_j_idx) =
746  ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx));
747  });
748  });
749 
750  auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
751  sp(sp_reg_idx).sp_compute,
752  sequence<1>{},
753  f_sum,
754  SMPLComputeDataType{0}); // rowsum(Pcompute{j})
755  static_assert(rowsum_p.thread_buf_.size() == 1,
756  "assuming that each thread holds 1 rowsum value");
757 #if defined(__gfx950__)
758  // assuming that we are using 32x32 mfma
759  int32x2_t swapped_regs =
760  __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
761  bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
762  false,
763  false);
764  rowsum_p.thread_buf_[0] = f_sum(bit_cast<SMPLComputeDataType>(swapped_regs.x),
765  bit_cast<SMPLComputeDataType>(swapped_regs.y));
766 #else
767  block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
768 #endif
769 
770  // l{j}
775  constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
776  sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
777  constexpr auto i_idx = make_tuple(idx0);
778  const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
779 
780  l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
781  });
782 
783  // update partial o_acc [0, fmha_alu_D_reg_cnt)
784  static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) {
785  o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale);
786  });
787 
792  static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
793  static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
794  float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
795  float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
796  if constexpr(std::is_same_v<PDataType, fp16_t>)
797  {
798  auto casted = detail::cvt_pk_fp16_f32(x, y);
799  sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
800  sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
801  }
802  else
803  {
804  auto casted = detail::cvt_pk_bf16_f32(x, y);
805  sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
806  sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
807  }
808  });
809 
813  };
814 
815  auto gemm = [&](auto sp_reg_idx, auto gemm_idx) {
816  if constexpr(gemm_idx == 0)
817  {
818  clear_tile(sp(sp_reg_idx).sp_compute); // initialize C
819  gemm_0(sp(sp_reg_idx).sp_compute,
820  get_slice_tile(q_tile,
821  sequence<0, (k0_loops - 1) * kK0>{},
822  sequence<kM0, k0_loops * kK0>{}),
823  get_slice_tile(kv_tile.k_tile,
824  sequence<0, (k0_loops - 1) * kK0>{},
825  sequence<kN0, k0_loops * kK0>{}));
826  }
827  else
828  {
829  gemm_1(o_acc,
830  get_slice_tile(sp(sp_reg_idx).p,
831  sequence<0, (k1_loops - 1) * kK1>{},
832  sequence<kM0, k1_loops * kK1>{}),
833  get_slice_tile(kv_tile.v_tile,
834  sequence<0, (k1_loops - 1) * kK1>{},
835  sequence<kN1, k1_loops * kK1>{}));
836  }
837  };
838 
839  auto cl_calc = [&](auto sp_reg_idx, auto gemm_idx) {
840  if constexpr(gemm_idx == 0)
841  {
842  clear_tile(sp(sp_reg_idx).sp_compute); // initialize C
843  gemm_0(sp(sp_reg_idx).sp_compute,
844  get_slice_tile(q_tile,
845  sequence<0, (k0_loops - 1) * kK0>{},
846  sequence<kM0, k0_loops * kK0>{}),
847  get_slice_tile(kv_tile.k_tile,
848  sequence<0, (k0_loops - 1) * kK0>{},
849  sequence<kN0, k0_loops * kK0>{}));
850  }
851  else
852  {
853  gemm_1(o_acc,
854  get_slice_tile(sp(sp_reg_idx).p,
855  sequence<0, (k1_loops - 1) * kK1>{},
856  sequence<kM0, k1_loops * kK1>{}),
857  get_slice_tile(kv_tile.v_tile,
858  sequence<0, (k1_loops - 1) * kK1>{},
859  sequence<kN1, k1_loops * kK1>{}));
860  fmha_alu0(number<1>{} - sp_reg_idx);
861  }
862  };
863 
864  auto fmha_alu_D_upd = [&] {
865  o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
866 
867  fp32x2_t pk_o_acc_scale;
868  pk_o_acc_scale.x = o_acc_scale;
869  pk_o_acc_scale.y = o_acc_scale;
870 
871  static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0);
872 #if CK_TILE_DISABLE_PACKED_FP32
873  static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size());
874  static_for<fmha_alu_D_reg_cnt, fmha_alu_D_reg_cnt + 2, 1>{}(
875  [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
876 #endif
877 
878  constexpr auto issued_D_reg_cnt =
879 #if CK_TILE_DISABLE_PACKED_FP32
880  fmha_alu_D_reg_cnt + 2
881 #else
882  fmha_alu_D_reg_cnt
883 #endif
884  ;
887  // update partial o_acc after [issued_D_reg_cnt]
888  static_for<issued_D_reg_cnt, o_acc.thread_buf_.size(), 2>{}([&](auto idx) {
889  fp32x2_t input;
890  input.x = o_acc.thread_buf_[idx];
891  input.y = o_acc.thread_buf_[idx + 1];
892 
893  auto output = detail::pk_mul_f32(input, pk_o_acc_scale);
894 
895  o_acc.thread_buf_[idx] = output.x;
896  o_acc.thread_buf_[idx + 1] = output.y;
897  });
898  };
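// [editor annotation] Taken together, fmha_alu0 / fmha_alu1 / fmha_alu_D_upd implement
// one step of the online-softmax recurrence, in base-2 exponentials (scale_s is assumed
// to already carry the log2(e) factor):
//
//   m_j = max(m_{j-1}, rowmax(S_j))
//   P_j = exp2(scale_s * (S_j - m_j))
//   r_j = exp2(scale_s * (m_{j-1} - m_j))   // o_acc_scale
//   l_j = r_j * l_{j-1} + rowsum(P_j)
//   O_j = r_j * O_{j-1} + P_j * V_j         // gemm_1 accumulates P_j * V_j
//
// fmha_alu_D_reg_cnt only controls where the O rescale is issued: the first registers
// are rescaled inside fmha_alu1, the remainder (packed two at a time) here.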
899 
900  auto fmha_mask = [&](auto sp_reg_idx) {
901  if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
902  {
903  bool need_perpixel_check = mask.IsEdgeTile(
904  q_origin.at(number<0>{}), kv_token_start, number<kM0>{}, number<kN0>{});
905  if(need_perpixel_check)
906  {
907  set_tile_if(sp(sp_reg_idx).sp_compute,
908  -numeric<SMPLComputeDataType>::infinity(),
909  [&](auto tile_idx) {
910  const auto row =
911  q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
912  const auto col = kv_token_start + tile_idx.at(number<1>{});
913  return mask.IsOutOfBound(row, col);
914  });
915  }
916  }
917  };
918 
919  auto cl_load = [&](auto load_type, auto mem_wr_idx, auto lds_rd_idx) {
920  if constexpr(load_type == 0)
921  {
922  V_mem_load(mem_wr_idx);
923  K_lds_load(lds_rd_idx);
924  }
925  else
926  {
927  K_mem_load(mem_wr_idx);
928  V_lds_load(lds_rd_idx);
929  }
930  };
931 
932  auto core_loop = [&](auto cl_p) {
933  auto gemm0 = number<0>{};
934  auto gemm1 = number<1>{};
935 
936  auto memV = number<0>{};
937  auto memK = number<1>{};
938 
939  using Scheduler = CoreLoopScheduler<Problem, FmhaMask::IsMasking>;
940 
941  auto iteration = [&](auto pi) {
942  auto xdl_SP_p01_reg_idx = number<1>{} - pi;
943  auto xdl_SP_p23_reg_idx = pi;
944 
945  auto K_w0_lds_wr_idx = number<1>{} - pi;
946  auto V_w0_lds_wr_idx = pi;
947  auto K_w0_lds_rd_idx = pi;
948  auto V_w0_lds_rd_idx = pi;
949 
950  auto K_w4_lds_wr_idx = number<1>{} - pi;
951  auto V_w4_lds_wr_idx = number<1>{} - pi;
952  auto K_w4_lds_rd_idx = number<1>{} - pi;
953  auto V_w4_lds_rd_idx = pi;
954 
955  bool result = true;
956 
957  if constexpr(cl_p == 0)
958  {
959 #if ADD_SBARRIER_FOR_PHASE0
960  __builtin_amdgcn_sched_barrier(0);
961  __builtin_amdgcn_s_barrier();
962 #endif
963  __builtin_amdgcn_sched_barrier(0);
964  // phase0
965  if constexpr(pi == 0)
966  {
967  ASM_MARKER("phase0 Wave0-3 (pi=0)");
968  }
969  else
970  {
971  ASM_MARKER("phase0 Wave0-3 (pi=1)");
972  }
973  s_waitcnt_lgkmcnt<0>();
974  __builtin_amdgcn_sched_barrier(0);
975  cl_calc(xdl_SP_p01_reg_idx, gemm0);
976  fmha_alu1(xdl_SP_p23_reg_idx);
977 
978  Scheduler::schedule(cl_p, number<0>{});
979  __builtin_amdgcn_sched_barrier(0);
980  // phase1
981  ASM_MARKER("phase1 Wave0-3");
982  s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
983  __builtin_amdgcn_sched_barrier(0);
984  __builtin_amdgcn_s_barrier();
985  __builtin_amdgcn_sched_barrier(0);
986  cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
987  Scheduler::schedule(cl_p, number<1>{});
988  fmha_mask(xdl_SP_p01_reg_idx);
989 
990  __builtin_amdgcn_sched_barrier(0);
991  // phase2
992  ASM_MARKER("phase2 Wave0-3");
993  s_waitcnt_lgkmcnt<0>();
994  __builtin_amdgcn_sched_barrier(0);
995  __builtin_amdgcn_s_barrier();
996  __builtin_amdgcn_sched_barrier(0);
997  asm volatile("s_nop 0");
998  __builtin_amdgcn_sched_barrier(0);
999  cl_calc(xdl_SP_p23_reg_idx, gemm1);
1000 
1001  Scheduler::schedule(cl_p, number<2>{});
1002  __builtin_amdgcn_sched_barrier(0);
1003  fmha_alu_D_upd();
1004 
1005  __builtin_amdgcn_sched_barrier(0);
1006  // phase3
1007  ASM_MARKER("phase3 Wave0-3");
1008  s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
1009  __builtin_amdgcn_sched_barrier(0);
1010  __builtin_amdgcn_s_barrier();
1011  __builtin_amdgcn_sched_barrier(0);
1012  cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);
1013 
1014  Scheduler::schedule(cl_p, number<3>{});
1015  kv_token_start += kN0;
1016  if(num_total_loop <= ++i_total_loops)
1017  {
1018  result = false;
1019  }
1020  }
1021  else
1022  {
1023 #if ADD_SBARRIER_FOR_PHASE0
1024  __builtin_amdgcn_sched_barrier(0);
1025  __builtin_amdgcn_s_barrier();
1026 #endif
1027  __builtin_amdgcn_sched_barrier(0);
1028  // phase0
1029  if constexpr(pi == 0)
1030  {
1031  ASM_MARKER("phase0 Wave4-7 (pi=0)");
1032  }
1033  else
1034  {
1035  ASM_MARKER("phase0 Wave4-7 (pi=1)");
1036  }
1037  cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx);
1038 
1039  Scheduler::schedule(cl_p, number<0>{});
1040  __builtin_amdgcn_sched_barrier(0);
1041  // phase1
1042  ASM_MARKER("phase1 Wave4-7");
1043  s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
1044  __builtin_amdgcn_sched_barrier(0);
1045  __builtin_amdgcn_s_barrier();
1046  __builtin_amdgcn_sched_barrier(0);
1047  asm volatile("s_nop 1");
1048  __builtin_amdgcn_sched_barrier(0);
1049  cl_calc(xdl_SP_p01_reg_idx, gemm0);
1050  fmha_alu1(xdl_SP_p23_reg_idx);
1051 
1052  Scheduler::schedule(cl_p, number<1>{});
1053  __builtin_amdgcn_sched_barrier(0);
1054  // phase2
1055  ASM_MARKER("phase2 Wave4-7");
1056  __builtin_amdgcn_s_barrier();
1057  __builtin_amdgcn_sched_barrier(0);
1058  cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
1059  Scheduler::schedule(cl_p, number<2>{});
1060  fmha_mask(xdl_SP_p01_reg_idx);
1061 
1062  kv_token_start += kN0;
1063  if(num_total_loop <= ++i_total_loops)
1064  {
1065  result = false;
1066  }
1067 
1068  __builtin_amdgcn_sched_barrier(0);
1069  // phase3
1070  ASM_MARKER("phase3 Wave4-7");
1071  s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
1072  __builtin_amdgcn_sched_barrier(0);
1073  __builtin_amdgcn_s_barrier();
1074  __builtin_amdgcn_sched_barrier(0);
1075  asm volatile("s_nop 1");
1076  __builtin_amdgcn_sched_barrier(0);
1077  cl_calc(xdl_SP_p23_reg_idx, gemm1);
1078 
1079  Scheduler::schedule(cl_p, number<3>{});
1080  __builtin_amdgcn_sched_barrier(0);
1081  fmha_alu_D_upd();
1082  }
1083  return result;
1084  };
1085  return iteration(number<0>{}) && iteration(number<1>{});
1086  };
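// [editor annotation] The core loop is a two-warp-group ping-pong: both groups execute
// the same four phases, but wave group 1 enters with the schedule shifted so its memory
// phases overlap wave group 0's MFMA phases. The s_barrier calls are the hand-off
// points, and __builtin_amdgcn_s_setprio below gives the groups different issue
// priorities. Each core_loop call runs two iterations (pi = 0, 1) so the double-buffer
// indices return to their starting parity before the loop test.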
1087 
1088  auto fmha_post_process = [&](auto d) {
1089  auto ps_pi = number<1>{} - d;
1090  auto V_lds_rd_idx = ps_pi;
1091 
1092  if(1 < num_total_loop)
1093  {
1094  s_waitcnt_vmcnt<K_mem_su_ld_insts>();
1095  }
1096  else
1097  {
1098  s_waitcnt_vmcnt<0>();
1099  }
1100  __builtin_amdgcn_s_barrier();
1101 
1102  V_lds_load(V_lds_rd_idx);
1103  fmha_alu1(ps_pi);
1104 
1105  s_waitcnt_lgkmcnt<0>();
1106 
1107  auto xdl_SP_p23_reg_idx = ps_pi;
1108  gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{});
1109  };
1110 
1111  // pre-stage
1112  {
1113  ASM_MARKER("before pre-stage");
1114  // (1) load K0 to LDS & VGPR
1115  K_mem_load(number<0>{}); // mem_K0
1116 
1117  s_waitcnt_vmcnt<0>();
1118  __builtin_amdgcn_s_barrier();
1119 
1120  K_lds_load(number<0>{}); // lds_K0
1121 
1122  s_waitcnt_lgkmcnt<0>();
1123  __builtin_amdgcn_s_barrier();
1124 
1125  // (2) prefetch K1 and V0 to LDS in parallel with GEMM0
1126  if(1 < num_total_loop)
1127  {
1128  K_mem_load(number<1>{}); // mem_K1
1129  }
1130  V_mem_load(number<0>{}); // mem_V0
1131 
1132  // (3) mfma (Q*K0) + softmax
1133  gemm(number<0>{}, /*gemm_idx=*/number<0>{});
1134 
1135  fmha_mask(number<0>{});
1137  fmha_alu0(number<0>{});
1138  fmha_alu_D_upd();
1139 
1140  kv_token_start += kN0;
1141  ++i_total_loops;
1142  if(num_total_loop <= i_total_loops)
1143  {
1144  goto label_main_loops_exit;
1145  }
1146 
1147  if(2 < num_total_loop)
1148  {
1149  K_mem_load(number<0>{}); // mem_K2
1150 
1151  s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
1152  __builtin_amdgcn_s_barrier();
1153  }
1154 
1155  ASM_MARKER("end pre-stage");
1156  }
1157 
1158  if(1 < num_total_loop)
1159  {
1160  if(warp_group_id == 0)
1161  {
1162  V_mem_load(number<1>{}); // V1
1163  K_lds_load(number<1>{}); // K1
1164 
1165  __builtin_amdgcn_s_setprio(0);
1166  __builtin_amdgcn_s_barrier();
1167  while(core_loop(number<0>{}))
1168  ;
1169  }
1170  if(warp_group_id != 0)
1171  {
1172  __builtin_amdgcn_s_setprio(1);
1173  __builtin_amdgcn_s_barrier();
1174  while(core_loop(number<1>{}))
1175  ;
1176  }
1177  }
1178  label_main_loops_exit:
1179  if(num_total_loop % 2)
1180  {
1181  fmha_post_process(number<1>{});
1182  }
1183  if(!(num_total_loop % 2))
1184  {
1185  fmha_post_process(number<0>{});
1186  }
1187 
1188  // store lse
1189  if constexpr(kStoreLSE)
1190  {
1191  auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
1192 
1193  constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
1194  sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) {
1195  constexpr auto i_idx = make_tuple(idx0);
1196  lse(i_idx) = m[i_idx] / C_LOG2E + log(l[i_idx]);
1197  });
1198 
1199  store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
1200  }
1201 
1202  // finally, O
1203  constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
1204 
1205  sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
1206  constexpr auto i_idx = make_tuple(idx0);
1207  const auto tmp = [&]() {
1208  if constexpr(FmhaMask::IsMasking)
1209  {
1210  return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
1211  }
1212  else
1213  return 1 / l[i_idx];
1214  }();
1215  sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
1216  constexpr auto i_j_idx = make_tuple(idx0, idx1);
1217  o_acc(i_j_idx) *= tmp;
1218  });
1219  });
1220 
1221  o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
1222 
1223  return o_acc;
1224  }
1225 
1226  template <typename QDramBlockWindowTmp,
1227  typename KDramBlockWindowTmp,
1228  typename VDramBlockWindowTmp,
1229  typename LSEDramBlockWindowTmp>
1230  CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
1231  const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
1232  const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
1233  LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
1234  FmhaMask mask,
1235  float scale_s,
1236  void* smem_ptr) const
1237  {
1238  using namespace ck_tile;
1239 
1240  return operator()(q_dram_block_window_tmp,
1241  identity{},
1242  k_dram_block_window_tmp,
1243  identity{},
1244  v_dram_block_window_tmp,
1245  identity{},
1246  lse_dram_block_window_tmp,
1247  identity{},
1248  identity{},
1249  identity{},
1250  identity{},
1251  mask,
1252  scale_s,
1253  smem_ptr);
1254  }
1255 };
1256 
1257 } // namespace ck_tile