Composable Kernel: include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 #define ENABLE_ASM_MARKER 1
11 #if ENABLE_ASM_MARKER
12 #define ASM_MARKER(marker) \
13  __builtin_amdgcn_sched_barrier(0); \
14  asm volatile("; [POYENC] " #marker); \
15  __builtin_amdgcn_sched_barrier(0);
16 #else
17 #define ASM_MARKER(marker)
18 #endif
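// Expansion sketch: ASM_MARKER(phase1) becomes
//   __builtin_amdgcn_sched_barrier(0);
//   asm volatile("; [POYENC] phase1");
//   __builtin_amdgcn_sched_barrier(0);
// i.e. a pure ISA comment fenced by scheduling barriers, so the label lands at
// a fixed spot in the generated assembly and nothing gets reordered across it.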
19 
20 #define ADD_SBARRIER_FOR_PHASE0 1
21 #if !defined(CK_TILE_DISABLE_PACKED_FP32)
22 #define CK_TILE_DISABLE_PACKED_FP32 0
23 #endif
24 
25 #define WARP_ID 0
26 #define LANE_ID 0
27 
28 #define ENABLE_DEBUG_STMTS 1
29 #if ENABLE_DEBUG_STMTS
30 #define DEBUG_STMTS \
31  if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID)
32 #else
33 #define DEBUG_STMTS if constexpr(false)
34 #endif
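// Usage sketch (variable name illustrative): guard device-side prints so only
// lane LANE_ID of warp WARP_ID in workgroup 0 emits them, e.g.
//   DEBUG_STMTS { printf("kv_token_start = %d\n", kv_token_start); }
// With ENABLE_DEBUG_STMTS == 0 the body is compiled out via if constexpr(false).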
35 
36 namespace ck_tile {
37 
38 template <typename PipelineProblem, bool kIsMasking>
39 struct CoreLoopScheduler;
40 
41 template <typename PipelineProblem>
42 struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
43 {
44  template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
45  CK_TILE_DEVICE static constexpr void schedule(ck_tile::number<WaveGroup>,
46  ck_tile::number<Phase>)
47  {
48  using namespace ck_tile;
49 
50  if constexpr(WaveGroup == 0)
51  {
52  if constexpr(Phase == 0)
53  {
54  static_for<0, 8, 1>{}([&](auto) {
55  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
56  __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
57  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
58  });
59  }
60  else if constexpr(Phase == 1) {}
61  else if constexpr(Phase == 2)
62  {
63 #if !CK_TILE_DISABLE_PACKED_FP32
64  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
65 #endif
66  static_for<0, 8, 1>{}([&](auto) {
67  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
68  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
69  });
70  }
71  else if constexpr(Phase == 3) {}
72  }
73  else
74  {
75  if constexpr(Phase == 0) {}
76  else if constexpr(Phase == 1)
77  {
78  static_for<0, 8, 1>{}([&](auto) {
79  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
80  __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
81  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
82  });
83  }
84  else if constexpr(Phase == 2) {}
85  else if constexpr(Phase == 3)
86  {
87 #if !CK_TILE_DISABLE_PACKED_FP32
88  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
89 #endif
90  static_for<0, 8, 1>{}([&](auto) {
91  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
92  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
93  });
94  }
95  }
96  }
97 };
98 
99 template <typename PipelineProblem>
100 struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
101 {
102  template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
103  CK_TILE_DEVICE static constexpr void schedule(ck_tile::number<WaveGroup>,
104  ck_tile::number<Phase>)
105  {
106  using namespace ck_tile;
107 
108  if constexpr(WaveGroup == 0)
109  {
110  if constexpr(Phase == 0)
111  {
112  static_for<0, 8, 1>{}([&](auto) {
113  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
114  __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
115  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
116  });
117  }
118  else if constexpr(Phase == 1) {}
119  else if constexpr(Phase == 2)
120  {
121 #if !CK_TILE_DISABLE_PACKED_FP32
122  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
123 #endif
124  static_for<0, 8, 1>{}([&](auto) {
125  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
126  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
127  });
128  }
129  else if constexpr(Phase == 3) {}
130  }
131  else
132  {
133  if constexpr(Phase == 0) {}
134  else if constexpr(Phase == 1)
135  {
136  static_for<0, 8, 1>{}([&](auto) {
137  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
138  __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
139  __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
140  });
141  }
142  else if constexpr(Phase == 2) {}
143  else if constexpr(Phase == 3)
144  {
145 #if !CK_TILE_DISABLE_PACKED_FP32
146  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
147 #endif
148  static_for<0, 8, 1>{}([&](auto) {
149  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
150  __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
151  });
152  }
153  }
154  }
155 };
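// Usage sketch (mirrors the core loop further down): the pipeline instantiates
//   using Scheduler = CoreLoopScheduler<Problem, FmhaMask::IsMasking>;
// and calls Scheduler::schedule(wave_group, phase) at each phase boundary; the
// sched_group_barrier sequences above then pin the MFMA/TRANS/VALU interleaving
// for that phase instead of leaving the ordering to the backend scheduler.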
156 
157 namespace detail {
158 CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
159 {
160 #if CK_TILE_DISABLE_PACKED_FP32
161  return a * b + c;
162 #else
163  float result;
164  asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]"
165  : [result] "=v"(result)
166  : [a] "v"(a), [b] "s"(b), [c] "v"(c));
167  return result;
168 #endif
169 }
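// Naming sketch: the _vsv suffix mirrors the operand constraints above, where a
// and c are VGPR operands ("v") and b is pinned to an SGPR ("s"); at the call
// site below this keeps the uniform scale_s out of the vector register budget.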
170 
171 CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
172 {
173  float result;
174  asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]"
175  : [result] "=v"(result)
176  : [lhs] "v"(lhs), [rhs] "v"(rhs));
177  return result;
178 }
179 
180 CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
181 {
182  fp16x2_t result;
183  asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]"
184  : [result] "=v"(result)
185  : [a] "v"(a), [b] "v"(b));
186  return result;
187 }
188 
189 CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
190 {
191  bf16x2_t result;
192  asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]"
193  : [result] "=v"(result)
194  : [a] "v"(a), [b] "v"(b));
195  return result;
196 }
197 
198 CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
199 {
200  fp32x2_t result;
201  asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]"
202  : [result] "=v"(result)
203  : [lhs] "v"(lhs), [rhs] "v"(rhs));
204  return result;
205 }
206 } // namespace detail
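// Intent sketch for these helpers: each packed instruction retires two fp32
// lanes at once (v_cvt_pk_*_f32 converts P to fp16/bf16 pairs in fmha_alu1(),
// v_pk_mul_f32 rescales o_acc pairs in fmha_alu_D_upd()), so building with
// CK_TILE_DISABLE_PACKED_FP32 == 0 spends roughly half the VALU instructions
// of the scalar fallback in the softmax epilogue.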
207 
208 template <typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
209 struct BlockFmhaFwdV3Pipeline
210 {
211  using Problem = ck_tile::remove_cvref_t<Problem_>;
212  using Policy = ck_tile::remove_cvref_t<Policy_>;
213  using QDataType = ck_tile::remove_cvref_t<typename Problem::QDataType>;
214  using KDataType = ck_tile::remove_cvref_t<typename Problem::KDataType>;
215  using VDataType = ck_tile::remove_cvref_t<typename Problem::VDataType>;
216  using SaccDataType = ck_tile::remove_cvref_t<typename Problem::SaccDataType>;
217  using SMPLComputeDataType = ck_tile::remove_cvref_t<typename Problem::SMPLComputeDataType>;
218  using LSEDataType = ck_tile::remove_cvref_t<typename Problem::LSEDataType>;
219  using PDataType = ck_tile::remove_cvref_t<typename Problem::PDataType>;
220  using OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
221  using ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType>;
222  using FmhaMask = ck_tile::remove_cvref_t<typename Problem::FmhaMask>;
223 
224  static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
225  "we will use the same dist tensor 'sp_compute' for both gemm0 & softmax");
226 
227  using BlockFmhaShape = ck_tile::remove_cvref_t<typename Problem::BlockFmhaShape>;
228 
229  static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize;
230 
231  static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
232  static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
233  static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
234  static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
235  static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
236  static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
237  static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
238 
239  static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
240 
241  static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
242  static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
243  static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
244  static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
245  static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
246  static constexpr bool kStoreLSE = Problem::kStoreLSE;
247 
248  // last dimension vector length used to create tensor view (and decide buffer_load vector length)
249  // ... together with tensor distribution. the tensor dist should be able to overwrite this
250  static constexpr ck_tile::index_t kAlignmentQ =
251  kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
252  static constexpr ck_tile::index_t kAlignmentK =
253  kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
254  static constexpr ck_tile::index_t kAlignmentV =
255  kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
256 
257  static constexpr ck_tile::index_t kAlignmentO =
258  kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
259 
260  static constexpr ck_tile::index_t kBlockPerCu = []() {
261  if constexpr(Problem::kBlockPerCu != -1)
262  return Problem::kBlockPerCu;
263  else
264  {
265  return 2;
266  }
267  }();
268 
269  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
270  {
271  // create another LDS buffer for p
272  return ck_tile::max(kM0 * kN1 * sizeof(PDataType),
273  Policy::template GetSmemSize<Problem>() +
274  kM0 * kN0 * sizeof(PDataType));
275  }
276 
277  // for debug only
278  template <ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock>
279  CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc()
280  {
281  using namespace ck_tile;
282  constexpr auto lds_block_desc =
283  make_naive_tensor_descriptor(make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
284  make_tuple(number<NPerBlock>{}, number<1>{}),
285  number<1>{},
286  number<1>{});
287 
288  return lds_block_desc;
289  }
290 
291  // for debug only
292  template <ck_tile::index_t MPerBlock>
293  CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc1D()
294  {
295  using namespace ck_tile;
296  constexpr auto lds_block_desc = make_naive_tensor_descriptor(
297  make_tuple(number<MPerBlock>{}), make_tuple(number<1>{}));
298 
299  return lds_block_desc;
300  }
301 
302  template <typename DataType, typename Descriptor>
303  CK_TILE_DEVICE static constexpr auto make_lds_tile_window(void* base, const Descriptor& desc)
304  {
305  using namespace ck_tile;
306 
307  auto tensor_view =
308  make_tensor_view<address_space_enum::lds>(reinterpret_cast<DataType*>(base), desc);
309  return make_tile_window(tensor_view, desc.get_lengths(), {0, 0});
310  }
311 
312  // vmcnt=0~63, lgkmcnt=0~15, expcnt=0~7
313  template <uint16_t Vmcnt, uint8_t Lgkmcnt, uint8_t Expcnt = 7>
314  CK_TILE_DEVICE static constexpr void s_waitcnt()
315  {
316  // vmcnt use bits {[15:14],[3:0]}
317  // expcnt use bits [6:4]
318  // lgkmcnt use bits [11:8]
319  __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) |
320  ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8));
321  }
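// Worked example (sketch of the encoding above): Vmcnt = 2 with the default
// Expcnt = 7 and Lgkmcnt = 0 packs to 0x0072; vmcnt's low nibble holds 2,
// expcnt sits in bits [6:4], lgkmcnt in bits [11:8], and vmcnt's high bits
// [15:14] are zero here.
static_assert(((((0b110000 & 2) << (14 - 4)) | (0b1111 & 2)) | ((0b111 & 7) << 4) |
               ((0b1111 & 0) << 8)) == 0x0072,
              "s_waitcnt immediate encoding example");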
322 
323  template <uint16_t Vmcnt>
324  CK_TILE_DEVICE static constexpr void s_waitcnt_vmcnt()
325  {
326  s_waitcnt<Vmcnt, 15>();
327  }
328 
329  template <uint8_t Lgkmcnt>
330  CK_TILE_DEVICE static constexpr void s_waitcnt_lgkmcnt()
331  {
332  s_waitcnt<63, Lgkmcnt>();
333  }
334 
335  template <typename QDramBlockWindowTmp,
336  typename KDramBlockWindowTmp,
337  typename VDramBlockWindowTmp,
338  typename LSEDramBlockWindowTmp,
339  typename QElementFunction,
340  typename KElementFunction,
341  typename VElementFunction,
342  typename LSEElementFunction,
343  typename SAccElementFunction,
344  typename PComputeElementFunction,
345  typename OAccElementFunction>
346  CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
347  const QElementFunction& q_element_func,
348  const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
349  [[maybe_unused]] const KElementFunction& k_element_func,
350  const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
351  [[maybe_unused]] const VElementFunction& v_element_func,
352  LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
353  const LSEElementFunction& lse_element_func,
354  [[maybe_unused]] const SAccElementFunction& s_acc_element_func,
355  const PComputeElementFunction& p_compute_element_func,
356  const OAccElementFunction& o_acc_element_func,
357  FmhaMask mask,
358  float scale_s,
359  void* smem_ptr) const
360  {
361  using namespace ck_tile;
362 
363  static_assert(
364  std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
365  std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
366  std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
367  "wrong!");
368 
369  static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
370  kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
371  kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
372  kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
373  kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
374  "wrong!");
375 
376  static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
377  auto s_lds = make_tensor_view<address_space_enum::lds>(
378  reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
379  MakeSimpleLdsDesc<kM0, kN0>());
380  [[maybe_unused]] auto s_lds_window =
381  make_tile_window(s_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
382 
383  auto p_lds = make_tensor_view<address_space_enum::lds>(
384  reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr) +
385  Policy::template GetSmemSize<Problem>()),
386  MakeSimpleLdsDesc<kM0, kN0>());
387  [[maybe_unused]] auto p_lds_window =
388  make_tile_window(p_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
389 
390  auto o_lds = make_tensor_view<address_space_enum::lds>(
391  reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
392  MakeSimpleLdsDesc<kM0, kN1>());
393  [[maybe_unused]] auto o_lds_window =
394  make_tile_window(o_lds, make_tuple(number<kM0>{}, number<kN1>{}), {0, 0});
395 
396  auto m_lds = make_tensor_view<address_space_enum::lds>(
397  reinterpret_cast<SMPLComputeDataType*>(static_cast<char*>(smem_ptr) +
398  Policy::template GetSmemSize<Problem>()),
399  MakeSimpleLdsDesc1D<kM0>());
400  [[maybe_unused]] auto m_lds_window =
401  make_tile_window(m_lds, make_tuple(number<kM0>{}), {0});
402 
403  const index_t warp_group_id = get_warp_id() / 4;
404 
405  // Block GEMM
406  constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
407  constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
408 
409  auto q_dram_window = make_tile_window_linear(
410  q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution<Problem>());
411 
412  // reduction function for softmax
413  const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
414  const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
415 
416  auto k_lds_window_store = generate_tuple(
417  [&](auto i_buf) {
418  return make_lds_tile_window<KDataType>(
419  smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
420  },
421  number<2>{});
422 
423  auto v_lds_window_store = generate_tuple(
424  [&](auto i_buf) {
425  return make_lds_tile_window<VDataType>(
426  smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
427  },
428  number<2>{});
429 
430  statically_indexed_array<decltype(make_tile_window(
431  make_lds_tile_window<KDataType>(
432  nullptr,
433  Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
434  Policy::template MakeKRegTileDistribution<Problem>())),
435  2>
436  k_lds_window_load;
437 
438  statically_indexed_array<decltype(make_tile_window(
439  make_lds_tile_window<VDataType>(
440  nullptr,
441  Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
442  Policy::template MakeVRegTileDistribution<Problem>())),
443  2>
444  v_lds_window_load;
445 
446  decltype(make_static_distributed_tensor<QDataType>(
447  Policy::template MakeQRegTileDistribution<Problem>())) q_tile;
448 
449  union kv_tile_type
450  {
451  CK_TILE_DEVICE kv_tile_type() {}
452 
453  decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile;
454 
455  decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile;
456  } kv_tile;
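// Register-reuse sketch: k_tile and v_tile are never live at the same time
// (gemm0 consumes k_tile before v_tile is loaded for gemm1), so overlaying
// them in one union halves the VGPR footprint of the K/V staging tiles.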
457 
458  union sp_compute_type
459  {
460  CK_TILE_DEVICE sp_compute_type() {}
461 
462  decltype(gemm_0.MakeCBlockTile()) sp_compute;
463  decltype(make_static_distributed_tensor<PDataType>(
464  Policy::template MakePRegTileDistribution<Problem>())) p;
465  };
466  statically_indexed_array<sp_compute_type, 2> sp;
467 
468  decltype(gemm_1.MakeCBlockTile()) o_acc;
469  constexpr index_t fmha_alu_D_reg_cnt = 0; // threshold deciding how many fmha_alu_D_upd()
470  // instructions we should move into fmha_alu1()
471  static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
472 
473  decltype(block_tile_reduce<SMPLComputeDataType>(
474  sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m;
475  decltype(m) l;
476 
477  // initialize k_lds_window and v_lds_window
478  static_for<0, 2, 1>{}([&](auto idx) {
479  k_lds_window_load(idx) = make_tile_window(
480  make_lds_tile_window<KDataType>(
481  static_cast<char*>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
482  Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
483  Policy::template MakeKRegTileDistribution<Problem>());
484  });
485 
486  static_for<0, 2, 1>{}([&](auto idx) {
487  v_lds_window_load(idx) =
488  make_tile_window(make_lds_tile_window<VDataType>(
489  static_cast<char*>(smem_ptr) +
490  (idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
491  Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
492  Policy::template MakeVRegTileDistribution<Problem>());
493  });
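// LDS layout implied by the offsets above (sketch): four ping-pong buffers of
// GetSmemSizeKV bytes each, laid out as [K0 | K1 | V0 | V1], so a store into
// buffer idx can overlap a load from buffer 1 - idx for both K and V.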
494 
495  {
496  auto origin_q = load_tile(q_dram_window);
497  auto transformed_q = tile_elementwise_in(q_element_func, origin_q);
498 
499  q_tile = transformed_q;
500  }
501 
502  clear_tile(o_acc);
503  set_tile(m, bit_cast<float>(0xff7fffff)); // -FLT_MAX, the lowest finite float (just above -infinity)
504  clear_tile(l);
505 
506  const auto q_origin = q_dram_window.get_window_origin();
507  const auto [seqlen_k_start, seqlen_k_end] =
508  mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
509 
510  const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
511  index_t kv_token_start = seqlen_k_start;
512 
513  // check early exit if no work to do
514  if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
515  {
516  if(num_total_loop <= 0)
517  {
518  if constexpr(kStoreLSE)
519  {
520  auto lse =
521  make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
522 
523  set_tile(lse, -numeric<LSEDataType>::infinity());
524 
525  store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
526  }
527 
528  // Note: here o_acc is already cleared, so return it as-is
529  // Note: q was loaded but without a fence; ignore it.
530  return o_acc;
531  }
532  }
533 
534  auto k_dram_window =
535  make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
536  k_dram_block_window_tmp.get_window_lengths(),
537  {seqlen_k_start, 0},
538  Policy::template MakeKDramTileDistribution<Problem>());
539  k_dram_window.init_raw();
540 
541  auto v_dram_window =
542  make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
543  v_dram_block_window_tmp.get_window_lengths(),
544  {seqlen_k_start, 0}, // TODO: hdim split?
545  Policy::template MakeVDramTileDistribution<Problem>());
546  v_dram_window.init_raw();
547 
548  // prefetch K tile
549  index_t i_total_loops = 0;
550  constexpr index_t k0_loops = kQKHeaddim / kK0;
551  constexpr index_t k1_loops = kN0 / kK1;
552  static_assert(1 == k0_loops);
553  static_assert(1 == k1_loops);
554  static_assert(kN0 == kK1);
555 
556  constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
557  static_assert(NumWarpGroups == 2);
558 
559  [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor, const char* name) {
560  printf("[POYENC] %s (size=%d): %5.2f",
561  name,
562  decltype(dist_tensor.thread_buf_)::size(),
563  ck_tile::type_convert<float>(dist_tensor.thread_buf_[0]));
564  static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) {
565  printf(", %5.2f", ck_tile::type_convert<float>(dist_tensor.thread_buf_[i]));
566  });
567  printf("\n");
568  };
569 
570  [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) {
571  const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{});
572  const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{});
573 
574  auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
575  auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
576 
577  if constexpr(true || num_rows < num_cols)
578  {
579  for(int row = 0; row < num_rows; ++row)
580  {
581  int offset = desc.calculate_offset(make_tuple(row, 0));
582  printf("[DEVICE] %s[%3d] = %5.2f",
583  name,
584  row,
585  ck_tile::type_convert<float>(data[offset]));
586  for(int col = 1; col < num_cols; ++col)
587  {
588  printf(", ");
589  offset = desc.calculate_offset(make_tuple(row, col));
590  printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
591  }
592  printf("\n");
593  }
594  }
595  else
596  {
597  for(int col = 0; col < num_cols; ++col)
598  {
599  int offset = desc.calculate_offset(make_tuple(0, col));
600  printf("[DEVICE] %s[%3d] = %5.2f",
601  name,
602  col,
603  ck_tile::type_convert<float>(data[offset]));
604  for(int row = 1; row < num_rows; ++row)
605  {
606  printf(", ");
607  offset = desc.calculate_offset(make_tuple(row, col));
608  printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
609  }
610  printf("\n");
611  }
612  }
613  };
614 
615  [[maybe_unused]] auto print_lds_1d = [&](auto lds_tile_window, const char* name) {
616  const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{});
617 
618  auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
619  auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
620 
621  int offset = desc.calculate_offset(make_tuple(0));
622  printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert<float>(data[offset]));
623  for(int e = 1; e < num_elems; ++e)
624  {
625  printf(", ");
626  offset = desc.calculate_offset(make_tuple(e));
627  printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
628  }
629  printf("\n");
630  };
631 
632  // K_mem_su_ld_insts = 1 for 32 x 128
633  // V_mem_su_ld_insts = 1 for 128 x 32
634  static constexpr int K_mem_su_ld_insts = 1;
635  static constexpr int V_mem_su_ld_insts = 1;
636 
637  auto K_mem_load = [&](auto k_lds_write_idx) {
638  async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
639 
641  // move K tile windows
642  move_tile_window(k_dram_window, {kN0, 0});
643  };
644 
645  auto K_lds_load = [&](auto k_lds_read_idx) {
646  kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx));
647  };
648 
649  auto V_mem_load = [&](auto v_lds_write_idx) {
650  async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
651  __builtin_amdgcn_sched_barrier(0);
652 
653  // move V tile windows
654  move_tile_window(v_dram_window, {kK1, 0});
655  };
656 
657  auto V_lds_load = [&](auto v_lds_read_idx) {
658  kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
659  };
660 
661  decltype(m) m_old;
662  SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd()
664  statically_indexed_array<decltype(sp(number<0>{}).sp_compute), 2> sp_delta;
665 
666  auto fmha_alu0 = [&](auto sp_reg_idx) {
667  m_old = m; // m{j-1}
668  static_assert(m.thread_buf_.size() == 1,
669  "assuming that each thread holds 1 rowmax value");
670  auto m_latest = block_tile_reduce<SMPLComputeDataType>(
671  sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]);
672 #if defined(__gfx950__)
673  // assuming that we are using 32x32 mfma
674  int32x2_t swapped_regs =
675  __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(m_latest.thread_buf_[0]),
676  bit_cast<int32_t>(m_latest.thread_buf_[0]),
677  false,
678  false);
680  m_latest.thread_buf_[0] = f_max(bit_cast<SMPLComputeDataType>(swapped_regs.x),
681  bit_cast<SMPLComputeDataType>(swapped_regs.y));
682 #else
683  block_tile_reduce_sync(m_latest, f_max, bool_constant<false>{});
684 #endif
685  m = m_latest;
686 
687  constexpr auto p_spans =
688  std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
689  sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
690  sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
691  constexpr auto i_j_idx = make_tuple(idx0, idx1);
692  sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
693  sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
694  });
695  });
697  };
698 
699  auto fmha_alu1 = [&](auto sp_reg_idx) {
700  constexpr auto p_spans =
701  std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
702  sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
703  sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
704  constexpr auto i_j_idx = make_tuple(idx0, idx1);
705  sp(sp_reg_idx).sp_compute(i_j_idx) =
706  ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx));
707  });
708  });
709 
710  auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
711  sp(sp_reg_idx).sp_compute,
712  sequence<1>{},
713  f_sum,
714  SMPLComputeDataType{0}); // rowsum(Pcompute{j})
715  static_assert(rowsum_p.thread_buf_.size() == 1,
716  "assuming that each thread holds 1 rowsum value");
717 #if defined(__gfx950__)
718  // assuming that we are using 32x32 mfma
719  int32x2_t swapped_regs =
720  __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
721  bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
722  false,
723  false);
724  rowsum_p.thread_buf_[0] = f_sum(bit_cast<SMPLComputeDataType>(swapped_regs.x),
725  bit_cast<SMPLComputeDataType>(swapped_regs.y));
726 #else
727  block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
728 #endif
729  // update partial o_acc [0, 2)
730  static_for<0, ck_tile::min(2, fmha_alu_D_reg_cnt), 1>{}(
731  [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
732 
733  // l{j}
734  constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
735  sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
736  constexpr auto i_idx = make_tuple(idx0);
737  const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
738 
739  l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
740  });
741 
742  // update partial o_acc [2, fmha_alu_D_reg_cnt)
743  static_for<2, ck_tile::max(2, fmha_alu_D_reg_cnt), 1>{}(
744  [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
745 
749  static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
750  static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
751  float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
752  float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
753  if constexpr(std::is_same_v<PDataType, fp16_t>)
754  {
755  auto casted = detail::cvt_pk_fp16_f32(x, y);
756  sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
757  sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
758  }
759  else
760  {
761  auto casted = detail::cvt_pk_bf16_f32(x, y);
762  sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
763  sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
764  }
765  });
766  };
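// Online-softmax recurrence implemented by fmha_alu0/fmha_alu1 (sketch, with
// scale_s assumed to carry the log2(e) factor, matching the m / C_LOG2E term
// in the LSE epilogue):
//   m_j   = max(m_{j-1}, rowmax(S_j))
//   P_j   = exp2(scale_s * (S_j - m_j))
//   l_j   = exp2(scale_s * (m_{j-1} - m_j)) * l_{j-1} + rowsum(P_j)
// fmha_alu_D_upd() then applies the same exp2(scale_s * (m_{j-1} - m_j))
// rescale to o_acc before gemm1 accumulates P_j @ V_j into it.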
767 
768  auto gemm = [&](auto sp_reg_idx, auto gemm_idx) {
769  if constexpr(gemm_idx == 0)
770  {
771  clear_tile(sp(sp_reg_idx).sp_compute); // initialize C
772  gemm_0(sp(sp_reg_idx).sp_compute,
773  get_slice_tile(q_tile,
774  sequence<0, (k0_loops - 1) * kK0>{},
775  sequence<kM0, k0_loops * kK0>{}),
776  get_slice_tile(kv_tile.k_tile,
777  sequence<0, (k0_loops - 1) * kK0>{},
778  sequence<kN0, k0_loops * kK0>{}));
779  }
780  else
781  {
782  gemm_1(o_acc,
783  get_slice_tile(sp(sp_reg_idx).p,
784  sequence<0, (k1_loops - 1) * kK1>{},
785  sequence<kM0, k1_loops * kK1>{}),
786  get_slice_tile(kv_tile.v_tile,
787  sequence<0, (k1_loops - 1) * kK1>{},
788  sequence<kN1, k1_loops * kK1>{}));
789  }
790  };
791 
792  auto cl_calc = [&](auto sp_reg_idx, auto gemm_idx) {
793  if constexpr(gemm_idx == 0)
794  {
795  clear_tile(sp(sp_reg_idx).sp_compute); // initialize C
796  gemm_0(sp(sp_reg_idx).sp_compute,
797  get_slice_tile(q_tile,
798  sequence<0, (k0_loops - 1) * kK0>{},
799  sequence<kM0, k0_loops * kK0>{}),
800  get_slice_tile(kv_tile.k_tile,
801  sequence<0, (k0_loops - 1) * kK0>{},
802  sequence<kN0, k0_loops * kK0>{}));
803  }
804  else
805  {
806  gemm_1(o_acc,
807  get_slice_tile(sp(sp_reg_idx).p,
808  sequence<0, (k1_loops - 1) * kK1>{},
809  sequence<kM0, k1_loops * kK1>{}),
810  get_slice_tile(kv_tile.v_tile,
811  sequence<0, (k1_loops - 1) * kK1>{},
812  sequence<kN1, k1_loops * kK1>{}));
813  fmha_alu0(number<1>{} - sp_reg_idx);
814  }
815  };
816 
817  auto fmha_alu_D_upd = [&] {
818  o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
819 
820  fp32x2_t pk_o_acc_scale;
821  pk_o_acc_scale.x = o_acc_scale;
822  pk_o_acc_scale.y = o_acc_scale;
823 
824  static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0);
825 #if CK_TILE_DISABLE_PACKED_FP32
826  static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size());
827  static_for<fmha_alu_D_reg_cnt, fmha_alu_D_reg_cnt + 2, 1>{}(
828  [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
829 #endif
830 
831  constexpr auto issued_D_reg_cnt =
832 #if CK_TILE_DISABLE_PACKED_FP32
833  fmha_alu_D_reg_cnt + 2
834 #else
835  fmha_alu_D_reg_cnt
836 #endif
837  ;
840  // update partial o_acc after [issued_D_reg_cnt]
841  static_for<issued_D_reg_cnt, o_acc.thread_buf_.size(), 2>{}([&](auto idx) {
842  fp32x2_t input;
843  input.x = o_acc.thread_buf_[idx];
844  input.y = o_acc.thread_buf_[idx + 1];
845 
846  auto output = detail::pk_mul_f32(input, pk_o_acc_scale);
847 
848  o_acc.thread_buf_[idx] = output.x;
849  o_acc.thread_buf_[idx + 1] = output.y;
850  });
851  };
852 
853  auto fmha_mask = [&](auto sp_reg_idx) {
854  if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
855  {
856  bool need_perpixel_check = mask.IsEdgeTile(
857  q_origin.at(number<0>{}), kv_token_start, number<kM0>{}, number<kN0>{});
858  if(need_perpixel_check)
859  {
860  set_tile_if(sp(sp_reg_idx).sp_compute,
861  -numeric<SMPLComputeDataType>::infinity(),
862  [&](auto tile_idx) {
863  const auto row =
864  q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
865  const auto col = kv_token_start + tile_idx.at(number<1>{});
866  return mask.IsOutOfBound(row, col);
867  });
868  }
869  }
870  };
871 
872  auto cl_load = [&](auto load_type, auto mem_wr_idx, auto lds_rd_idx) {
873  if constexpr(load_type == 0)
874  {
875  V_mem_load(mem_wr_idx);
876  K_lds_load(lds_rd_idx);
877  }
878  else
879  {
880  K_mem_load(mem_wr_idx);
881  V_lds_load(lds_rd_idx);
882  }
883  };
884 
885  auto core_loop = [&](auto cl_p) {
886  auto gemm0 = number<0>{};
887  auto gemm1 = number<1>{};
888 
889  auto memV = number<0>{};
890  auto memK = number<1>{};
891 
892  using Scheduler = CoreLoopScheduler<Problem, FmhaMask::IsMasking>;
893 
894  auto iteration = [&](auto pi) {
895  auto xdl_SP_p01_reg_idx = number<1>{} - pi;
896  auto xdl_SP_p23_reg_idx = pi;
897 
898  auto K_w0_lds_wr_idx = number<1>{} - pi;
899  auto V_w0_lds_wr_idx = pi;
900  auto K_w0_lds_rd_idx = pi;
901  auto V_w0_lds_rd_idx = pi;
902 
903  auto K_w4_lds_wr_idx = number<1>{} - pi;
904  auto V_w4_lds_wr_idx = number<1>{} - pi;
905  auto K_w4_lds_rd_idx = number<1>{} - pi;
906  auto V_w4_lds_rd_idx = pi;
907 
908  bool result = true;
909 
910  if constexpr(cl_p == 0)
911  {
912 #if ADD_SBARRIER_FOR_PHASE0
913  __builtin_amdgcn_sched_barrier(0);
914  __builtin_amdgcn_s_barrier();
915 #endif
916  __builtin_amdgcn_sched_barrier(0);
917  // phase0
918  if constexpr(pi == 0)
919  {
920  ASM_MARKER("phase0 Wave0-3 (pi=0)");
921  }
922  else
923  {
924  ASM_MARKER("phase0 Wave0-3 (pi=1)");
925  }
926  s_waitcnt_lgkmcnt<0>();
927  __builtin_amdgcn_sched_barrier(0);
928  cl_calc(xdl_SP_p01_reg_idx, gemm0);
929  fmha_alu1(xdl_SP_p23_reg_idx);
930 
931  Scheduler::schedule(cl_p, number<0>{});
932  __builtin_amdgcn_sched_barrier(0);
933  // phase1
934  ASM_MARKER("phase1 Wave0-3");
935  s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
936  __builtin_amdgcn_sched_barrier(0);
937  __builtin_amdgcn_s_barrier();
938  __builtin_amdgcn_sched_barrier(0);
939  cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
940  fmha_mask(xdl_SP_p01_reg_idx);
941 
942  Scheduler::schedule(cl_p, number<1>{});
943  __builtin_amdgcn_sched_barrier(0);
944  // phase2
945  ASM_MARKER("phase2 Wave0-3");
946  s_waitcnt_lgkmcnt<0>();
947  __builtin_amdgcn_sched_barrier(0);
948  __builtin_amdgcn_s_barrier();
949  __builtin_amdgcn_sched_barrier(0);
950  cl_calc(xdl_SP_p23_reg_idx, gemm1);
951 
952  Scheduler::schedule(cl_p, number<2>{});
953  __builtin_amdgcn_sched_barrier(0);
954  fmha_alu_D_upd();
955 
956  __builtin_amdgcn_sched_barrier(0);
957  // phase3
958  ASM_MARKER("phase3 Wave0-3");
959  s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
960  __builtin_amdgcn_sched_barrier(0);
961  __builtin_amdgcn_s_barrier();
962  __builtin_amdgcn_sched_barrier(0);
963  cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);
964 
965  Scheduler::schedule(cl_p, number<3>{});
966  kv_token_start += kN0;
967  if(num_total_loop <= ++i_total_loops)
968  {
969  result = false;
970  }
971  }
972  else
973  {
974 #if ADD_SBARRIER_FOR_PHASE0
975  __builtin_amdgcn_sched_barrier(0);
976  __builtin_amdgcn_s_barrier();
977 #endif
978  __builtin_amdgcn_sched_barrier(0);
979  // phase0
980  if constexpr(pi == 0)
981  {
982  ASM_MARKER("phase0 Wave4-7 (pi=0)");
983  }
984  else
985  {
986  ASM_MARKER("phase0 Wave4-7 (pi=1)");
987  }
988  cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx);
989 
990  Scheduler::schedule(cl_p, number<0>{});
991  __builtin_amdgcn_sched_barrier(0);
992  // phase1
993  ASM_MARKER("phase1 Wave4-7");
994  s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
995  __builtin_amdgcn_sched_barrier(0);
996  __builtin_amdgcn_s_barrier();
997  __builtin_amdgcn_sched_barrier(0);
998  cl_calc(xdl_SP_p01_reg_idx, gemm0);
999  fmha_alu1(xdl_SP_p23_reg_idx);
1000 
1001  Scheduler::schedule(cl_p, number<1>{});
1002  __builtin_amdgcn_sched_barrier(0);
1003  // phase2
1004  ASM_MARKER("phase2 Wave4-7");
1005  __builtin_amdgcn_s_barrier();
1006  __builtin_amdgcn_sched_barrier(0);
1007  cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
1008  fmha_mask(xdl_SP_p01_reg_idx);
1009 
1010  Scheduler::schedule(cl_p, number<2>{});
1011  kv_token_start += kN0;
1012  if(num_total_loop <= ++i_total_loops)
1013  {
1014  result = false;
1015  }
1016 
1017  __builtin_amdgcn_sched_barrier(0);
1018  // phase3
1019  ASM_MARKER("phase3 Wave4-7");
1020  s_waitcnt<K_mem_su_ld_insts + V_mem_su_ld_insts, 0>();
1021  __builtin_amdgcn_sched_barrier(0);
1022  __builtin_amdgcn_s_barrier();
1023  __builtin_amdgcn_sched_barrier(0);
1024  cl_calc(xdl_SP_p23_reg_idx, gemm1);
1025 
1026  Scheduler::schedule(cl_p, number<3>{});
1027  __builtin_amdgcn_sched_barrier(0);
1028  fmha_alu_D_upd();
1029  }
1030  return result;
1031  };
1032  return iteration(number<0>{}) && iteration(number<1>{});
1033  };
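// Ping-pong sketch: warp group 0 runs phases {gemm0+alu, loadK, gemm1, loadV}
// and warp group 1 runs the same four phases shifted by one, so one group's
// MFMA-heavy phases overlap the other group's memory phases; the s_barrier at
// each phase boundary keeps the two groups in lock-step.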
1034 
1035  auto fmha_post_process = [&](auto d) {
1036  auto ps_pi = number<1>{} - d;
1037  auto V_lds_rd_idx = ps_pi;
1038 
1039  s_waitcnt_vmcnt<K_mem_su_ld_insts>();
1040  __builtin_amdgcn_s_barrier();
1041 
1042  V_lds_load(V_lds_rd_idx);
1043  fmha_alu1(ps_pi);
1044 
1045  s_waitcnt_lgkmcnt<0>();
1046 
1047  auto xdl_SP_p23_reg_idx = ps_pi;
1048  gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{});
1049  };
1050 
1051  // pre-stage
1052  {
1053  ASM_MARKER("before pre-stage");
1054  // (1) load K0 to LDS & VGPR
1055  K_mem_load(number<0>{}); // mem_K0
1056 
1057  s_waitcnt_vmcnt<0>();
1058  __builtin_amdgcn_s_barrier();
1059 
1060  K_lds_load(number<0>{}); // lds_K0
1061 
1062  s_waitcnt_lgkmcnt<0>();
1063  __builtin_amdgcn_s_barrier();
1064 
1065  // (2) prefetch K1 and V0 to LDS in parallel with GEMM0
1066  if(1 < num_total_loop)
1067  {
1068  K_mem_load(number<1>{}); // mem_K1
1069  }
1070  V_mem_load(number<0>{}); // mem_V0
1071 
1072  // (3) mfma (Q*K0) + softmax
1073  gemm(number<0>{}, /*gemm_idx=*/number<0>{});
1074 
1075  fmha_mask(number<0>{});
1077  fmha_alu0(number<0>{});
1078  fmha_alu_D_upd();
1079 
1080  kv_token_start += kN0;
1081  ++i_total_loops;
1082  if(num_total_loop <= i_total_loops)
1083  {
1084  goto label_main_loops_exit;
1085  }
1086 
1087  if(2 < num_total_loop)
1088  {
1089  K_mem_load(number<0>{}); // mem_K2
1090 
1091  s_waitcnt_vmcnt<K_mem_su_ld_insts + V_mem_su_ld_insts>();
1092  __builtin_amdgcn_s_barrier();
1093  }
1094 
1095  ASM_MARKER("end pre-stage");
1096  }
1097 
1098  if(1 < num_total_loop)
1099  {
1100  if(warp_group_id == 0)
1101  {
1102  V_mem_load(number<1>{}); // V1
1103  K_lds_load(number<1>{}); // K1
1104 
1105  asm volatile("s_setprio 0");
1106  __builtin_amdgcn_s_barrier();
1107  while(core_loop(number<0>{}))
1108  ;
1109  }
1110  if(warp_group_id != 0)
1111  {
1112  asm volatile("s_setprio 1");
1113  __builtin_amdgcn_s_barrier();
1114  while(core_loop(number<1>{}))
1115  ;
1116  }
1117  }
1118  label_main_loops_exit:
1119  if(num_total_loop % 2)
1120  {
1121  fmha_post_process(number<1>{});
1122  }
1123  if(!(num_total_loop % 2))
1124  {
1125  fmha_post_process(number<0>{});
1126  }
1127 
1128  // store lse
1129  if constexpr(kStoreLSE)
1130  {
1131  auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
1132 
1133  constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
1134  sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) {
1135  constexpr auto i_idx = make_tuple(idx0);
1136  lse(i_idx) = m[i_idx] / C_LOG2E + log(l[i_idx]);
1137  });
1138 
1139  store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
1140  }
1141 
1142  // finally, O
1143  constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
1144 
1145  sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
1146  constexpr auto i_idx = make_tuple(idx0);
1147  const auto tmp = [&]() {
1148  if constexpr(FmhaMask::IsMasking)
1149  {
1150  return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
1151  }
1152  else
1153  return 1 / l[i_idx];
1154  }();
1155  sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
1156  constexpr auto i_j_idx = make_tuple(idx0, idx1);
1157  o_acc(i_j_idx) *= tmp;
1158  });
1159  });
1160 
1161  o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
1162 
1163  return o_acc;
1164  }
1165 
1166  template <typename QDramBlockWindowTmp,
1167  typename KDramBlockWindowTmp,
1168  typename VDramBlockWindowTmp,
1169  typename LSEDramBlockWindowTmp>
1170  CK_TILE_HOST_DEVICE auto
1171  operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
1172  const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
1173  const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
1174  LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
1175  FmhaMask mask,
1176  float scale_s,
1177  void* smem_ptr) const
1178  {
1179  using namespace ck_tile;
1180 
1181  return operator()(q_dram_block_window_tmp,
1182  identity{},
1183  k_dram_block_window_tmp,
1184  identity{},
1185  v_dram_block_window_tmp,
1186  identity{},
1187  lse_dram_block_window_tmp,
1188  identity{},
1189  identity{},
1190  identity{},
1191  identity{},
1192  mask,
1193  scale_s,
1194  smem_ptr);
1195  }
1196 };
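// Invocation sketch (caller-side names hypothetical): a kernel typically uses
// the identity-functor overload above, e.g.
//   auto o_acc = pipeline(q_window, k_window, v_window, lse_window,
//                         mask, scale_s, smem_ptr);
// which forwards to the full operator() with identity{} for every elementwise
// transform.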
1197 
1198 } // namespace ck_tile