Composable Kernel: include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp (source file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
// (further #include lines were elided in the rendered listing)

namespace ck_tile {

template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
struct BlockFmhaPipelineQRKSVSWholeKPrefetch
{
    using Problem               = remove_cvref_t<Problem_>;
    using Policy                = remove_cvref_t<Policy_>;
    using QDataType             = remove_cvref_t<typename Problem::QDataType>;
    using KDataType             = remove_cvref_t<typename Problem::KDataType>;
    using VDataType             = remove_cvref_t<typename Problem::VDataType>;
    using SaccDataType          = remove_cvref_t<typename Problem::SaccDataType>;
    using SMPLComputeDataType   = remove_cvref_t<typename Problem::SMPLComputeDataType>;
    using BiasDataType          = remove_cvref_t<typename Problem::BiasDataType>;
    using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
    using LSEDataType           = remove_cvref_t<typename Problem::LSEDataType>;
    using PDataType             = remove_cvref_t<typename Problem::PDataType>;
    using OaccDataType          = remove_cvref_t<typename Problem::OaccDataType>;
    using ODataType             = remove_cvref_t<typename Problem::ODataType>;
    using FmhaMask              = remove_cvref_t<typename Problem::FmhaMask>;
    using AttentionVariant      = remove_cvref_t<typename Problem::AttentionVariant>;

    using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
    using VLayout        = remove_cvref_t<typename BlockFmhaShape::VLayout>;

    static constexpr bool kQLoadOnce = true;
    static_assert(kQLoadOnce == Policy::QLoadOnce);

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    static constexpr index_t kM0 = BlockFmhaShape::kM0;
    static constexpr index_t kN0 = BlockFmhaShape::kN0;
    static constexpr index_t kK0 = BlockFmhaShape::kK0;
    static constexpr index_t kN1 = BlockFmhaShape::kN1;
    static constexpr index_t kK1 = BlockFmhaShape::kK1;
    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;

    static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");

    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV =
        (kQKHeaddim < kSubQKHeaddim) ? true : Problem::kPadHeadDimV;
    static constexpr auto BiasEnum = Problem::BiasEnum;
    static constexpr bool kStoreLSE = Problem::kStoreLSE;
    static constexpr bool kHasDropout = Problem::kHasDropout;
    static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;

    // Last-dimension vector length used to create the tensor view (and to decide the
    // buffer_load vector length), together with the tensor distribution. The tensor
    // distribution should be able to override this.
    static constexpr index_t kAlignmentQ =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
    static constexpr index_t kAlignmentK =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV = []() {
        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            return Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
        else
            return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
    }();

    static constexpr index_t kAlignmentO =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
    static constexpr index_t kAlignmentBias =
        kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
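    // A worked example of the alignment/vector-width relationship (illustrative
    // numbers, not taken from this file): for fp16 data, an alignment of 8 lets a
    // thread read 8 contiguous elements (16 bytes) per buffer_load, while forcing the
    // alignment to 1 on a padded dimension degrades those loads to scalar accesses,
    // since a padded row may end mid-vector. For V, the contiguous ("fast") dimension
    // depends on VLayout: row-major V is contiguous along the head dimension,
    // otherwise along the sequence dimension, which is why kAlignmentV above checks
    // Problem::kPadHeadDimV in one branch and kPadSeqLenK in the other.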

    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            if constexpr(kQKHeaddim == 32)
            {
                return 2;
            }
            else if constexpr(kQKHeaddim == 64)
            {
                return 2;
            }
            else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128)
            {
                // NOTE: the nested `if constexpr` condition selecting between the two
                // values below was elided in the rendered listing.
                if constexpr(true /* condition elided */)
                    return 1;
                else
                    return 2;
            }
            else if constexpr(kQKHeaddim == 256)
            {
                return 1;
            }
            else
            {
                return 1;
            }
        }
    }();
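    // Occupancy heuristic: a larger head dimension keeps more of Q in registers and
    // larger K/V tiles in LDS, so fewer blocks fit per compute unit. The table above
    // trades occupancy against per-block resources when the Problem does not pin
    // kBlockPerCu explicitly.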

    static constexpr const char* name = "qr_async";

    using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename QElementFunction,
              typename KElementFunction,
              typename VElementFunction,
              typename BiasElementFunction,
              typename LSEElementFunction,
              typename SAccElementFunction,
              typename PComputeElementFunction,
              typename OAccElementFunction,
              typename PositionEncoding,
              typename AttentionVariantParams,
              typename BlockIndices>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
               const QElementFunction& q_element_func,
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
               const KElementFunction& k_element_func,
               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
               const VElementFunction& v_element_func,
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               const BiasElementFunction& bias_element_func,
               RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
               LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
               const LSEElementFunction& lse_element_func,
               const SAccElementFunction& s_acc_element_func,
               const PComputeElementFunction& p_compute_element_func,
               const OAccElementFunction& o_acc_element_func,
               FmhaMask mask,
               PositionEncoding position_encoding,
               float scale_s,
               const AttentionVariant& /* unused */,
               const AttentionVariantParams& /* unused */,
               const BlockIndices& /* unused */,
               void* smem_ptr,
               DropoutType& dropout) const
    {
        ignore = q_element_func;
        ignore = k_element_func;

        static_assert(
            std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
                std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        constexpr auto I0 = number<0>{};
        constexpr auto I1 = number<1>{};

        constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;
        static_assert(2 <= k0_loops);
        static_assert(2 <= k1_loops);

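        // Loop-count bookkeeping: gemm_0 (S = Q * K^T) is unrolled k0_loops times
        // along the shared head dimension, and gemm_1 (O += P * V) k1_loops times
        // along the key tile. Illustrative numbers (not taken from this file):
        // kQKHeaddim = 128 with kK0 = 32 gives k0_loops = 4, and kN0 = 128 with
        // kK1 = 32 gives k1_loops = 4.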
        constexpr bool kPreloadWholeNextIterationK =
            Policy::template IsPreloadWholeNextIterationK<Problem>();

        constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
        constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
        constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();

        static_assert(NumKLdsBuffers >= 2);
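        // At least two K LDS buffers are required for the ping-pong scheme below:
        // while gemm_0 consumes the chunk in buffer i % NumKLdsBuffers, the store of
        // the next chunk targets a different buffer, so stores can overlap the
        // preceding gemm without a hazard.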

        auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                                              q_dram_block_window_tmp.get_window_lengths(),
                                              q_dram_block_window_tmp.get_window_origin(),
                                              Policy::template MakeQRegTileDistribution<Problem>());

        const auto q_origin = q_dram_window.get_window_origin();
        const auto [seqlen_k_start, seqlen_k_end] =
            mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

        auto k_dram_block_window =
            make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                             k_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_k_start, 0});

        auto k_dram_window =
            make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
                             k_dram_block_window.get_window_lengths(),
                             k_dram_block_window.get_window_origin(),
                             Policy::template MakeKDramTileDistribution<Problem>());

        using k_tile_type = decltype(load_tile(k_dram_window));

        auto k_tiles = [&]() {
            if constexpr(kPreloadWholeNextIterationK)
                return statically_indexed_array<k_tile_type, k0_loops>{};
            else
                return statically_indexed_array<k_tile_type, 2>{};
        }();

        k_tiles[I0] = load_tile(k_dram_window);
        move_tile_window(k_dram_window, {0, kK0});

        auto q_tile = load_tile(q_dram_window);

        __builtin_amdgcn_sched_barrier(0);
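        // sched_barrier(0) stops the compiler scheduler from moving any instruction
        // across this point, keeping the Q/K prefetch issue order fixed relative to
        // the LDS setup that follows.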

        // K tile in LDS
        KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
        auto k_lds = make_tensor_view<address_space_enum::lds>(
            k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
        auto k_lds_window = make_tile_window(
            k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});

        using k_lds_window_type =
            decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{}));

        statically_indexed_array<k_lds_window_type, NumKLdsBuffers> k_lds_windows;

        static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) {
            k_lds_windows[i_buf] = get_slice_tile(
                k_lds_window, sequence<i_buf * kN0, 0>{}, sequence<(i_buf + 1) * kN0, kK0>{});
        });

        auto v_dram_window =
            make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
                             v_dram_block_window_tmp.get_window_lengths(),
                             {0, seqlen_k_start}, // TODO: hdim split?
                             Policy::template MakeVDramTileDistribution<Problem>());
        // V tile in LDS
        auto v_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
                                         Policy::template GetExclusiveKLdsBytes<Problem>()),
            Policy::template MakeVLdsBlockDescriptor<Problem>());
        auto v_lds_window = make_tile_window(
            v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});

        using v_tile_type = decltype(load_tile(v_dram_window));

        statically_indexed_array<v_tile_type, NumPrefetchV> v_tiles;

        using v_lds_window_type =
            decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{}));

        statically_indexed_array<v_lds_window_type, NumVLdsBuffers> v_lds_windows;

        static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) {
            v_lds_windows[i_buf] = get_slice_tile(
                v_lds_window, sequence<i_buf * kN1, 0>{}, sequence<(i_buf + 1) * kN1, kK1>{});
        });

        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
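        // gemm_0 accumulates S = Q * K^T (kM0 x kN0) from kK0-wide chunks of the head
        // dimension; gemm_1 accumulates O += P * V (kM0 x kN1) from kK1-wide chunks of
        // the key tile. Both read their second operand through the LDS windows set up
        // above.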

        using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
        auto s_acc = SaccBlockTileType{};

        // reduction functions for softmax
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

        // infer Sacc, S, P, M, L, Oacc type
        using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));

        using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
            SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));

        using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());

        // init Oacc, M, L
        auto o_acc = OaccBlockTileType{};
        auto m = MLBlockTileType{};
        auto l = MLBlockTileType{};

        clear_tile(o_acc);
        set_tile(m, -numeric<SMPLComputeDataType>::infinity());
        clear_tile(l);
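        // Running-softmax state across key tiles j (the standard FlashAttention-style
        // recurrence, spelled out for reference):
        //   m_j = max(m_{j-1}, rowmax(S_j))
        //   l_j = exp(m_{j-1} - m_j) * l_{j-1} + rowsum(exp(S_j - m_j))
        //   O_j = exp(m_{j-1} - m_j) * O_{j-1} + exp(S_j - m_j) * V_j
        // m starts at -inf and l at 0, so the first iteration reduces to a plain
        // softmax over its own tile; O is normalized by 1/l once in the epilogue.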

        const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);

        // check for early exit if there is no work to do
        if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
        {
            if(num_total_loop <= 0)
            {
                if constexpr(kStoreLSE)
                {
                    auto lse =
                        make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

                    set_tile(lse, -numeric<LSEDataType>::infinity());

                    store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
                }

                // Note: o_acc is already cleared here, so return it as-is.
                // Note: Q was loaded but without a fence; safe to ignore it.
                return o_acc;
            }
        }

        const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
        auto bias_dram_window =
            make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                             bias_dram_block_window_tmp.get_window_lengths(),
                             {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
                             Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());

        auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
            randval_dram_block_window_tmp, seqlen_k_start);

        q_tile = tile_elementwise_in(q_element_func, q_tile);

        index_t i_total_loops = 0;

        do
        {
            if constexpr(kPreloadWholeNextIterationK)
            {
                if(i_total_loops == 0) // executed by the first iteration
                {
                    if(num_total_loop > 1) // there are multiple iterations
                    {
                        static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
                            store_tile(
                                k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
                                tile_elementwise_in(k_element_func, k_tiles[number<i_k0>{}]));

                            k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
                            if constexpr(i_k0 < k0_loops - 2)
                                move_tile_window(k_dram_window, {0, kK0});

                            if constexpr(i_k0 == 0)
                                clear_tile(s_acc);

                            block_sync_lds();
                            // execute current unroll of gemm_0
                            gemm_0(s_acc,
                                   get_slice_tile(q_tile,
                                                  sequence<0, i_k0 * kK0>{},
                                                  sequence<kM0, (i_k0 + 1) * kK0>{}),
                                   k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
                        });

                        store_tile(
                            k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
                            tile_elementwise_in(k_element_func, k_tiles[number<k0_loops - 1>{}]));

                        // prefetch first v_tile
                        v_tiles[I0] = load_tile(v_dram_window);
                        move_tile_window(v_dram_window, {0, kK1});

                        move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});

                        // prefetch all k_tiles for the next iteration
                        static_for<0, k0_loops, 1>{}([&](auto i_k0) {
                            k_tiles[number<i_k0>{}] = load_tile(k_dram_window);

                            if constexpr(i_k0 < k0_loops - 1)
                                move_tile_window(k_dram_window, {0, kK0});
                        });

                        move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});

                        block_sync_lds();
                        // execute last unroll of gemm_0
                        gemm_0(s_acc,
                               get_slice_tile(q_tile,
                                              sequence<0, (k0_loops - 1) * kK0>{},
                                              sequence<kM0, k0_loops * kK0>{}),
                               k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
                    }
                    else // there is only a single iteration
                    {
                        static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
                            store_tile(
                                k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
                                tile_elementwise_in(k_element_func, k_tiles[number<i_k0>{}]));

                            k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
                            if constexpr(i_k0 < k0_loops - 2)
                                move_tile_window(k_dram_window, {0, kK0});

                            if constexpr(i_k0 == 0)
                                clear_tile(s_acc);

                            block_sync_lds();
                            // execute current unroll of gemm_0
                            gemm_0(s_acc,
                                   get_slice_tile(q_tile,
                                                  sequence<0, i_k0 * kK0>{},
                                                  sequence<kM0, (i_k0 + 1) * kK0>{}),
                                   k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
                        });

                        store_tile(
                            k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
                            tile_elementwise_in(k_element_func, k_tiles[number<k0_loops - 1>{}]));

                        // prefetch first v_tile
                        v_tiles[I0] = load_tile(v_dram_window);
                        move_tile_window(v_dram_window, {0, kK1});

                        block_sync_lds();
                        gemm_0(s_acc,
                               get_slice_tile(q_tile,
                                              sequence<0, (k0_loops - 1) * kK0>{},
                                              sequence<kM0, k0_loops * kK0>{}),
                               k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);

                        // move_tile_window(k_dram_window, {0, -k0_loops * kK0});
                    }
                }
                else // executed by intermediate and last iterations
                {
                    if(i_total_loops < num_total_loop - 1) // intermediate iteration
                    {
                        store_tile(k_lds_windows[I0],
                                   tile_elementwise_in(k_element_func, k_tiles[I0]));

                        // prefetch first v_tile
                        v_tiles[I0] = load_tile(v_dram_window);
                        move_tile_window(v_dram_window, {0, kK1});

                        clear_tile(s_acc);
                        block_sync_lds();
                        gemm_0(s_acc,
                               get_slice_tile(q_tile, sequence<0, 0>{}, sequence<kM0, kK0>{}),
                               k_lds_windows[I0]);

                        store_tile(k_lds_windows[I1],
                                   tile_elementwise_in(k_element_func, k_tiles[I1]));

                        move_tile_window(k_dram_window, {kN0, 0});

                        // prefetch first k_tile for the next iteration
                        k_tiles[I0] = load_tile(k_dram_window);
                        move_tile_window(k_dram_window, {0, kK0});

                        k_tiles[I1] = load_tile(k_dram_window);
                        if constexpr(1 < k0_loops - 1)
                            move_tile_window(k_dram_window, {0, kK0});

                        block_sync_lds();
                        gemm_0(s_acc,
                               get_slice_tile(q_tile, sequence<0, kK0>{}, sequence<kM0, 2 * kK0>{}),
                               k_lds_windows[I1]);

                        // during the gemm loop, also prefetch the other k_tiles for the
                        // next iteration
                        static_for<2, k0_loops, 1>{}([&](auto i_k0) {
                            store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
                                       k_tiles[number<i_k0>{}]);

                            k_tiles[number<i_k0>{}] = load_tile(k_dram_window);
                            if constexpr(i_k0 < k0_loops - 1)
                                move_tile_window(k_dram_window, {0, kK0});

                            block_sync_lds();
                            gemm_0(s_acc,
                                   get_slice_tile(q_tile,
                                                  sequence<0, i_k0 * kK0>{},
                                                  sequence<kM0, (i_k0 + 1) * kK0>{}),
                                   k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
                        });

                        move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});
                    }
                    else // last iteration
                    {
                        store_tile(k_lds_windows[I0],
                                   tile_elementwise_in(k_element_func, k_tiles[I0]));

                        // prefetch first v_tile
                        v_tiles[I0] = load_tile(v_dram_window);
                        move_tile_window(v_dram_window, {0, kK1});

                        clear_tile(s_acc);
                        block_sync_lds();
                        gemm_0(s_acc,
                               get_slice_tile(q_tile, sequence<0, 0>{}, sequence<kM0, kK0>{}),
                               k_lds_windows[I0]);

                        static_for<1, k0_loops, 1>{}([&](auto i_k0) {
                            store_tile(
                                k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
                                tile_elementwise_in(k_element_func, k_tiles[number<i_k0>{}]));

                            block_sync_lds();
                            gemm_0(s_acc,
                                   get_slice_tile(q_tile,
                                                  sequence<0, i_k0 * kK0>{},
                                                  sequence<kM0, (i_k0 + 1) * kK0>{}),
                                   k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
                        });
                    }
                }
            }
            else // only preload one unroll of K for the next iteration
            {
                static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
                    store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
                               tile_elementwise_in(k_element_func, k_tiles[I0]));
                    if constexpr(i_k0 == 0)
                        clear_tile(s_acc);

                    if constexpr(i_k0 < k0_loops - 1)
                        k_tiles[I0] = load_tile(k_dram_window);
                    if constexpr(i_k0 < k0_loops - 2)
                        move_tile_window(k_dram_window, {0, kK0});

                    block_sync_lds();
                    // execute current unroll of gemm_0
                    gemm_0(s_acc,
                           get_slice_tile(q_tile,
                                          sequence<0, i_k0 * kK0>{},
                                          sequence<kM0, (i_k0 + 1) * kK0>{}),
                           k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
                });

                store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
                           tile_elementwise_in(k_element_func, k_tiles[I0]));

                // prefetch first v_tile
                v_tiles[I0] = load_tile(v_dram_window);
                move_tile_window(v_dram_window, {0, kK1});

                block_sync_lds();
                gemm_0(s_acc,
                       get_slice_tile(q_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kM0, k0_loops * kK0>{}),
                       k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
            }

            __builtin_amdgcn_sched_barrier(0);

            const auto bias_tile = load_tile(bias_dram_window); // load bias tile

            static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
                v_tiles[i_buf] = load_tile(v_dram_window);
                move_tile_window(v_dram_window, {0, kK1});
            });

            // STAGE 2, scale_s, add bias, mask, softmax
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                        x += type_convert<SaccDataType>(bias_element_func(y));
#else
                        x += log2e_v<SaccDataType> *
                             type_convert<SaccDataType>(bias_element_func(y));
#endif
                    },
                    s_acc,
                    bias_tile);
            }
            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                const auto k_origin = k_dram_block_window.get_window_origin();
                constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            s_acc.get_tile_distribution(), make_tuple(idx0, idx1));

                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);

                        s_acc(i_j_idx) *= scale_s;
                        position_encoding.update(s_acc(i_j_idx), row, col);
                    });
                });
            }
            else
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
            }
            move_tile_window(bias_dram_window, {0, kN0});
            if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
            {
                const auto k_origin = k_dram_block_window.get_window_origin();
                bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
                                                           k_origin.at(number<0>{}),
                                                           number<kM0>{},
                                                           number<kN0>{});
                if(need_perpixel_check)
                {
                    set_tile_if(
                        s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
                            const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                            const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                            return mask.IsOutOfBound(row, col);
                        });
                }
            }

            const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
            auto m_local = block_tile_reduce<SMPLComputeDataType>(
                s,
                sequence<1>{},
                f_max,
                -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
            block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});

            const auto m_old = m; // m{j-1}
            tile_elementwise_inout(
                [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}

            auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
                s.get_tile_distribution()); // Pcompute{j}

            static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
                // NOTICE: bias might be a materialized mask that includes -inf
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             FmhaMask::IsMasking)
                {
                    return raw_m == -numeric<SMPLComputeDataType>::infinity()
                               ? type_convert<SMPLComputeDataType>(0.f)
                               : raw_m;
                }
                else
                {
                    return raw_m;
                }
            };

            constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
            sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
                    }
                    else
                    {
                        p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
                    }
#else
                    p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
                });
            });

            auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
                p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})

            block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
            // l{j}, Oacc{j}
            constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
            sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                const auto tmp = [&]() {
                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                    }
                    else
                    {
                        auto row_max = scale_s * get_validated_m(m[i_idx]);
                        return exp2(scale_s * m_old[i_idx] - row_max);
                    }
                }();
#else
                const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
                l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
                sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    // FIXME: this uses a different equation from the FA v2 paper,
                    // but produces the correct result. Is the equation wrong?
                    o_acc(i_j_idx) *= tmp;
                });
            });
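            // On the FIXME above: the update matches the FA v2 paper once its
            // per-step 1/l_j normalization is deferred. The paper rescales O by
            // l_{j-1}/l_j every step; this kernel applies only the exp(m_old - m)
            // correction ("tmp") and divides by the final l once in the epilogue,
            // which is algebraically identical.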

            if constexpr(kHasDropout)
            {
                auto randval_ptr =
                    reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>();
                dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
                    smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
            }

            __builtin_amdgcn_sched_barrier(0x7f);
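            // sched_barrier with a non-zero mask partially relaxes the fence: per the
            // LLVM AMDGPU intrinsic documentation, the mask bits select instruction
            // classes that may be scheduled across; 0x7f covers the ALU/MFMA/VMEM
            // classes but not DS (LDS), so the surrounding LDS traffic stays put.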

            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                    Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
                shuffle_tile(v_shuffle_tmp, v_tiles[I0]);

                store_tile(
                    v_lds_windows[I0],
                    tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
            }
            else
            {
                store_tile(v_lds_windows[I0],
                           tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch
            }

            __builtin_amdgcn_sched_barrier(0);

            const auto p =
                cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));

            if constexpr(!kPreloadWholeNextIterationK)
            {
                if(i_total_loops < num_total_loop - 1)
                {
                    move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
                    k_tiles[I0] = load_tile(k_dram_window);
                    move_tile_window(k_dram_window, {0, kK0});
                }

                __builtin_amdgcn_sched_barrier(0);
            }
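            // The {kN0, -(k0_loops - 1) * kK0} step above advances the K window to the
            // next kN0-row key block while rewinding the head-dim offset accumulated
            // by the preceding k0 unrolls, so the next iteration starts at offset 0
            // again.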

            // STAGE 3, KV gemm
            if constexpr(k1_loops > 1)
            {
                if constexpr(NumPrefetchV == 1) // NumVLdsBuffers == 2
                {
                    static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
                        v_tiles[I0] = load_tile(v_dram_window);

                        block_sync_lds();
                        gemm_1(o_acc,
                               get_slice_tile(
                                   p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
                               v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);

                        if constexpr(std::is_same_v<VLayout,
                                                    ck_tile::tensor_layout::gemm::RowMajor>)
                        {
                            auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                                Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
                            shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
                            store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
                                       tile_elementwise_in(v_element_func, v_shuffle_tmp));
                        }
                        else
                        {
                            store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
                                       tile_elementwise_in(v_element_func, v_tiles[I0]));
                        }

                        move_tile_window(v_dram_window, {0, kK1});
                    });
                }
                else // NumVLdsBuffers == 3 or 2
                {
                    static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
                        if constexpr(i_k1 < k1_loops - NumPrefetchV)
                            v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);

                        block_sync_lds();
                        gemm_1(o_acc,
                               get_slice_tile(
                                   p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
                               v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);

                        if constexpr(std::is_same_v<VLayout,
                                                    ck_tile::tensor_layout::gemm::RowMajor>)
                        {
                            auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                                Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
                            shuffle_tile(v_shuffle_tmp,
                                         v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
                            store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
                                       tile_elementwise_in(v_element_func, v_shuffle_tmp));
                        }
                        else
                        {
                            store_tile(
                                v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
                                tile_elementwise_in(v_element_func,
                                                    v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
                        }

                        if constexpr(i_k1 < k1_loops - NumPrefetchV)
                            move_tile_window(v_dram_window, {0, kK1});
                    });
                }
            }
            // move K tile windows
            move_tile_window(k_dram_block_window, {kN0, 0});

            block_sync_lds();
            gemm_1(o_acc,
                   get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
                   v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]);

            if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer<Problem>())
            {
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_barrier();
            }

        } while(++i_total_loops < num_total_loop);

        // store lse
        if constexpr(kStoreLSE)
        {
            auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

            constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
            sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             BiasEnum == BlockAttentionBiasEnum::ALIBI)
                {
                    lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
                }
                else
                {
                    lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
                }
#else
                lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif
            });
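            // What the branches above compute: with l = rowsum(exp(s - m)), the
            // row-wise log-sum-exp is LSE = m + log(l). Under CK_TILE_FMHA_FWD_FAST_EXP2,
            // m is kept in the log2 domain (and pre-scaled by scale_s in the unbiased
            // path), so dividing by C_LOG2E (= log2(e)) converts it back to the
            // natural-log domain before adding log(l).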

            store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
        }

        // finally, O
        constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();

        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            const auto tmp = [&]() {
                if constexpr(FmhaMask::IsMasking)
                {
                    return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
                }
                else
                    return 1 / l[i_idx];
            }();
            sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                o_acc(i_j_idx) *= tmp;
            });
        });
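        // Epilogue normalization: o_acc holds sum_j exp(S_j - m) * V_j and l holds the
        // matching rowsum of exponentials, so scaling by 1/l yields softmax(Q*K^T) * V.
        // The masking guard maps fully masked rows (l == 0) to 0 instead of NaN.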

        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);

        return o_acc;
    }

    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename PositionEncoding,
              typename AttentionVariantParams,
              typename BlockIndices>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
               LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
               FmhaMask mask,
               PositionEncoding position_encoding,
               float scale_s,
               const AttentionVariant& variant,
               const AttentionVariantParams& variant_params,
               const BlockIndices& block_indices,
               void* smem_ptr,
               DropoutType& dropout) const
    {
        return operator()(q_dram_block_window_tmp,
                          identity{},
                          k_dram_block_window_tmp,
                          identity{},
                          v_dram_block_window_tmp,
                          identity{},
                          bias_dram_block_window_tmp,
                          identity{},
                          randval_dram_block_window_tmp,
                          lse_dram_block_window_tmp,
                          identity{},
                          identity{},
                          identity{},
                          identity{},
                          mask,
                          position_encoding,
                          scale_s,
                          variant,
                          variant_params,
                          block_indices,
                          smem_ptr,
                          dropout);
    }
};

} // namespace ck_tile