/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp Source File
fmha_fwd_pagedkv_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
21 
22 namespace ck_tile {
23 
24 // TODO: This class is a variant of the existing FmhaFwdSplitKVKernel pipeline.
25 // Refactoring to extract shared logic is recommended as future work.
26 template <typename FmhaPipeline_, typename EpiloguePipeline_>
28 {
31  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
32  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
33 
34  static_assert(kBlockPerCu > 0);
35  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
36 
44 
46 
47  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
48  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
49  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
50  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
51  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
52  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
53  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
54  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
55  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
56  static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
57  static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
58 
61  static constexpr bool kHasMask = FmhaMask::IsMasking;
62 
63  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
64 
65  // clang-format off
66  template <typename T> struct t2s;
67  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
68  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
69  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
70  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
71  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
72  // clang-format on
73 
    /// Compose the canonical, human-readable name of this kernel instantiation.
    /// The exact token layout is produced/parsed by generate.py, so any change to
    /// the format here must be mirrored there. The name encodes: head dim, Q data
    /// type, batch/group mode, block tile shape, gemm0/gemm1 block-warp layouts
    /// and warp tiles, occupancy override, pipeline name, V layout (row/col),
    /// padding flags, and the logits-soft-cap / bias / mask / LSE / skip /
    /// fp8-static-quant / paged-KV trait tags.
    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        using bfs = typename FmhaPipeline::BlockFmhaShape;
        using g0br = typename bfs::Gemm0BlockWarps; // gemm0 (S = Q@K^T) block-warp layout
        using g1br = typename bfs::Gemm1BlockWarps; // gemm1 (O = P@V) block-warp layout
        using g0wt = typename bfs::Gemm0WarpTile;   // gemm0 per-warp tile shape
        using g1wt = typename bfs::Gemm1WarpTile;   // gemm1 per-warp tile shape
        #define _SS_ std::string
        #define _TS_ std::to_string
        // Padding tag: "p" + one letter group per padded dimension
        // (seqlen_q -> "s", seqlen_k -> "sk", hdim_q -> "d", hdim_v -> "dv");
        // empty string when nothing is padded (rendered as "_npad" below).
        auto pn = [&] () {
        std::string n;
        if (kPadSeqLenQ) n += "s";
        if (kPadSeqLenK) n += "sk";
        if (kPadHeadDimQ) n += "d";
        if (kPadHeadDimV) n += "dv";
        return n.empty() ? n : std::string("p") + n; }();
        return
        _SS_("fmha_fwd_pagedkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
        "_" + (kIsGroupMode ? "group" : "batch") + "_"
        "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
        _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
        "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
        "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
        "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
        "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
        (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
        "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
        (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
        (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
        #undef _SS_
        #undef _TS_
        // clang-format on
    }
109 
110  template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
111  // arg
113  {
114  };
115 
116  // kargs use aggregate initializer, so no constructor will provided
117  // use inheritance to minimize karg size
118  // user need to use MakeKargs() function to create kargs.
120  {
121  const void* q_ptr;
122  const void* k_ptr;
123  const void* v_ptr;
124  void* o_ptr;
125 
130 
132  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
133  // if this param is larger than 1, indicate MQA/GQA case
135  float scale_s;
136 
141 
146  };
147 
149  {
151 
152  void init_logits_soft_cap(float logits_soft_cap_)
153  {
154  if(0 < logits_soft_cap_)
155  {
156  logits_soft_cap = logits_soft_cap_;
158  }
159  else
160  {
161  logits_soft_cap = 0.f;
162  logits_soft_cap_rcp = 0.f;
163  }
164  }
165 
168  };
169 
171  {
172  const void* bias_ptr = nullptr;
175  };
176 
178  {
180  };
181 
183  {
184  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
185  const void* alibi_slope_ptr;
186  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
187  };
188 
190  {
191  // ck_tile::index_t window_size_left, window_size_right;
194  };
195 
197  {
198  float scale_p;
199  float scale_o;
200  };
201 
203  {
204  void* lse_ptr = nullptr;
207  };
208 
210  {
212  };
213 
215  {
219  };
220 
222  {
223  bool is_gappy = false;
224  };
225 
227  {
229  };
230 
233  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
234  FmhaFwdBatchModeBiasKargs,
235  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
236  FmhaFwdAlibiKargs,
237  FmhaFwdEmptyKargs<0>>>,
238  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
239  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
240  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
241  std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>,
242  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>
243  {
245 
247  ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
248  // single kcache page-block
249  ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
250  // single vcache page-block
252  };
253 
256  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
257  FmhaFwdCommonBiasKargs,
258  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
259  FmhaFwdAlibiKargs,
260  FmhaFwdEmptyKargs<0>>>,
261  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
262  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
263  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
264  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>,
265  std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, FmhaFwdEmptyKargs<5>>,
266  std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
267  {
271 
272  ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
273  // for single kcache page-block
274  ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
275  // for single vcache page-block
276  };
277 
278  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
279 
281  {
285  };
286 
287  template <bool Cond = !kIsGroupMode>
288  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
289  MakeKargsImpl(const void* q_ptr,
290  const void* k_ptr,
291  const void* v_ptr,
292  const void* bias_ptr,
293  void* lse_ptr,
294  void* o_ptr,
295  ck_tile::index_t seqlen_q,
296  ck_tile::index_t seqlen_k,
297  const void* seqlen_k_ptr, // only used for (paged-) kvcache
298  ck_tile::index_t hdim_q,
299  ck_tile::index_t hdim_v,
300  ck_tile::index_t num_head_q,
301  ck_tile::index_t nhead_ratio_qk,
302  const void* block_table_ptr,
303  ck_tile::index_t batch_stride_block_table,
304  ck_tile::index_t page_block_size,
305  const void* cache_batch_idx,
306  float scale_s,
307  float scale_p,
308  float scale_o,
309  float logits_soft_cap,
310  ck_tile::index_t stride_q,
311  ck_tile::index_t stride_k,
312  ck_tile::index_t stride_v,
313  ck_tile::index_t stride_bias,
314  ck_tile::index_t stride_o,
315  ck_tile::index_t nhead_stride_q,
316  ck_tile::index_t nhead_stride_k,
317  ck_tile::index_t nhead_stride_v,
318  ck_tile::index_t nhead_stride_bias,
319  ck_tile::index_t nhead_stride_lse,
320  ck_tile::index_t nhead_stride_o,
321  ck_tile::index_t batch_stride_q,
322  ck_tile::index_t batch_stride_k,
323  ck_tile::index_t batch_stride_v,
324  ck_tile::index_t batch_stride_bias,
325  ck_tile::index_t batch_stride_lse,
326  ck_tile::index_t batch_stride_o,
327  ck_tile::index_t window_size_left,
328  ck_tile::index_t window_size_right,
329  ck_tile::index_t mask_type)
330  {
331  Kargs kargs{{q_ptr,
332  k_ptr,
333  v_ptr,
334  o_ptr,
335  seqlen_q,
336  seqlen_k,
337  hdim_q,
338  hdim_v,
339  num_head_q,
340  nhead_ratio_qk,
341 #if CK_TILE_FMHA_FWD_FAST_EXP2
342  static_cast<float>(scale_s * ck_tile::log2e_v<>),
343 #else
344  scale_s,
345 #endif
346  stride_q,
347  stride_k,
348  stride_v,
349  stride_o,
350  nhead_stride_q,
351  nhead_stride_k,
352  nhead_stride_v,
353  nhead_stride_o}, // args for common karg
354  {}, // placeholder for bias
355  {}, // placeholder for mask
356  {}, // placeholder for lse
357  {}, // placeholder for fp8_static_quant args
358  {}, // placeholder for pagedkv
359  {}, // placeholder for logits_soft_cap
360  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
361  batch_stride_q,
362  batch_stride_k,
363  batch_stride_v,
364  batch_stride_o};
365 
367  {
368  kargs.bias_ptr = bias_ptr;
369  kargs.stride_bias = stride_bias;
370  kargs.nhead_stride_bias = nhead_stride_bias;
371  kargs.batch_stride_bias = batch_stride_bias;
372  }
373  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
374  {
375  kargs.alibi_slope_ptr = bias_ptr;
376  kargs.alibi_slope_stride = stride_bias;
377  }
378  if constexpr(kHasMask)
379  {
380  kargs.window_size_left = window_size_left;
381  kargs.window_size_right = window_size_right;
382  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
383  }
384  if constexpr(kStoreLSE)
385  {
386  kargs.lse_ptr = lse_ptr;
387  kargs.nhead_stride_lse = nhead_stride_lse;
388  kargs.batch_stride_lse = batch_stride_lse;
389  }
390  if constexpr(kDoFp8StaticQuant)
391  {
392  kargs.scale_p = scale_p;
393  kargs.scale_o = scale_o;
394  }
395  if constexpr(kIsPagedKV)
396  {
397  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
398  kargs.batch_stride_block_table = batch_stride_block_table;
399  kargs.page_block_size = page_block_size;
400  }
401  else
402  {
403  kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
404  }
405  if constexpr(kHasLogitsSoftCap)
406  {
407  kargs.init_logits_soft_cap(logits_soft_cap);
408  }
409 
410  return kargs;
411  }
412 
    // std::variant<> can't take in a list initializer, overload for backward compatibility
    /// Batch-mode kernel-argument factory (enabled only when !kIsGroupMode).
    ///
    /// Thin forwarder kept as the stable public entry point; all argument packing
    /// and trait-dependent field initialization happens in MakeKargsImpl(). The
    /// parameter list groups: tensor base pointers, per-batch sequence lengths
    /// (seqlen_k_ptr overrides seqlen_k when non-null, for (paged-) kvcache),
    /// head dims and MQA/GQA head ratio, paged-KV block-table / cache-batch-idx
    /// inputs, the scaling factors, then row/head/batch strides for each tensor,
    /// and finally the attention-mask window parameters.
    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              const void* seqlen_k_ptr, // only used for (paged-) kvcache
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              const void* block_table_ptr,
              ck_tile::index_t batch_stride_block_table,
              ck_tile::index_t page_block_size,
              const void* cache_batch_idx,
              float scale_s,
              float scale_p,
              float scale_o,
              float logits_soft_cap,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_q,
              ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v,
              ck_tile::index_t batch_stride_bias,
              ck_tile::index_t batch_stride_lse,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type)
    {
        // forward everything unchanged, in the same order
        return MakeKargsImpl(q_ptr,
                             k_ptr,
                             v_ptr,
                             bias_ptr,
                             lse_ptr,
                             o_ptr,
                             seqlen_q,
                             seqlen_k,
                             seqlen_k_ptr,
                             hdim_q,
                             hdim_v,
                             num_head_q,
                             nhead_ratio_qk,
                             block_table_ptr,
                             batch_stride_block_table,
                             page_block_size,
                             cache_batch_idx,
                             scale_s,
                             scale_p,
                             scale_o,
                             logits_soft_cap,
                             stride_q,
                             stride_k,
                             stride_v,
                             stride_bias,
                             stride_o,
                             nhead_stride_q,
                             nhead_stride_k,
                             nhead_stride_v,
                             nhead_stride_bias,
                             nhead_stride_lse,
                             nhead_stride_o,
                             batch_stride_q,
                             batch_stride_k,
                             batch_stride_v,
                             batch_stride_bias,
                             batch_stride_lse,
                             batch_stride_o,
                             window_size_left,
                             window_size_right,
                             mask_type);
    }
500 
501  template <bool Cond = kIsGroupMode>
502  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
503  MakeKargsImpl(const void* q_ptr,
504  const void* k_ptr,
505  const void* v_ptr,
506  const void* bias_ptr,
507  void* lse_ptr,
508  void* o_ptr,
509  const void* seqstart_q_ptr,
510  const void* seqstart_k_ptr,
511  const void* seqlen_k_ptr,
512  ck_tile::index_t hdim_q,
513  ck_tile::index_t hdim_v,
514  ck_tile::index_t num_head_q,
515  ck_tile::index_t nhead_ratio_qk,
516  const void* block_table_ptr,
517  ck_tile::index_t batch_stride_block_table,
518  ck_tile::index_t page_block_size,
519  bool is_gappy,
520  float scale_s,
521  float scale_p,
522  float scale_o,
523  float logits_soft_cap,
524  ck_tile::index_t stride_q,
525  ck_tile::index_t stride_k,
526  ck_tile::index_t stride_v,
527  ck_tile::index_t stride_bias,
528  ck_tile::index_t stride_o,
529  ck_tile::index_t nhead_stride_q,
530  ck_tile::index_t nhead_stride_k,
531  ck_tile::index_t nhead_stride_v,
532  ck_tile::index_t nhead_stride_bias,
533  ck_tile::index_t nhead_stride_lse,
534  ck_tile::index_t nhead_stride_o,
535  ck_tile::index_t batch_stride_k, // only used for paged-kvcache
536  ck_tile::index_t batch_stride_v, // only used for paged-kvcache
537  ck_tile::index_t window_size_left,
538  ck_tile::index_t window_size_right,
539  ck_tile::index_t mask_type,
540  ck_tile::index_t min_seqlen_q)
541  {
542  Kargs kargs{{q_ptr,
543  k_ptr,
544  v_ptr,
545  o_ptr,
546  -1, // seqlen will be updated by another pointer
547  -1, //
548  hdim_q,
549  hdim_v,
550  num_head_q,
551  nhead_ratio_qk,
552 #if CK_TILE_FMHA_FWD_FAST_EXP2
553  static_cast<float>(scale_s * ck_tile::log2e_v<>),
554 #else
555  scale_s,
556 #endif
557  stride_q,
558  stride_k,
559  stride_v,
560  stride_o,
561  nhead_stride_q,
562  nhead_stride_k,
563  nhead_stride_v,
564  nhead_stride_o}, // args for common karg
565  {}, // placeholder for bias
566  {}, // placeholder for mask
567  {}, // placeholder for lse
568  {}, // placeholder for fp8_static_quant args
569  {}, // placeholder for logits_soft_cap
570  {}, // placeholder for pagdkv
571  {}, // placeholder for min_seqlen_q
572  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
573  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
574  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
575  batch_stride_k,
576  batch_stride_v};
577 
579  {
580  kargs.bias_ptr = bias_ptr;
581  kargs.stride_bias = stride_bias;
582  kargs.nhead_stride_bias = nhead_stride_bias;
583  }
584  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
585  {
586  kargs.alibi_slope_ptr = bias_ptr;
587  kargs.alibi_slope_stride = stride_bias;
588  }
589  if constexpr(kHasMask)
590  {
591  kargs.window_size_left = window_size_left;
592  kargs.window_size_right = window_size_right;
593  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
594  }
595  if constexpr(kStoreLSE)
596  {
597  kargs.lse_ptr = lse_ptr;
598  kargs.nhead_stride_lse = nhead_stride_lse;
599  }
600  if constexpr(kDoFp8StaticQuant)
601  {
602  kargs.scale_p = scale_p;
603  kargs.scale_o = scale_o;
604  }
605  if constexpr(kHasLogitsSoftCap)
606  {
607  kargs.init_logits_soft_cap(logits_soft_cap);
608  }
609  if constexpr(kIsPagedKV)
610  {
611  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
612  kargs.batch_stride_block_table = batch_stride_block_table;
613  kargs.page_block_size = page_block_size;
614  kargs.is_gappy = is_gappy;
615  }
616  if constexpr(kSkipMinSeqlenQ)
617  {
618  kargs.min_seqlen_q = min_seqlen_q;
619  }
620 
621  return kargs;
622  }
623 
    // std::variant<> can't take in a list initializer, overload for backward compatibility
    /// Group-mode kernel-argument factory (enabled only when kIsGroupMode).
    ///
    /// Thin forwarder kept as the stable public entry point; all argument packing
    /// and trait-dependent field initialization happens in MakeKargsImpl().
    /// Unlike batch mode, per-batch sequence extents come from seqstart_q_ptr /
    /// seqstart_k_ptr (cumulative offsets) plus an optional seqlen_k_ptr override,
    /// there are no per-batch strides for Q/bias/LSE/O, and min_seqlen_q allows
    /// skipping short groups when kSkipMinSeqlenQ is enabled.
    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              void* lse_ptr,
              void* o_ptr,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              const void* seqlen_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              const void* block_table_ptr,
              ck_tile::index_t batch_stride_block_table,
              ck_tile::index_t page_block_size,
              bool is_gappy,
              float scale_s,
              float scale_p,
              float scale_o,
              float logits_soft_cap,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_k, // only used for paged-kvcache
              ck_tile::index_t batch_stride_v, // only used for paged-kvcache
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              ck_tile::index_t min_seqlen_q)
    {
        // forward everything unchanged, in the same order
        return MakeKargsImpl(q_ptr,
                             k_ptr,
                             v_ptr,
                             bias_ptr,
                             lse_ptr,
                             o_ptr,
                             seqstart_q_ptr,
                             seqstart_k_ptr,
                             seqlen_k_ptr,
                             hdim_q,
                             hdim_v,
                             num_head_q,
                             nhead_ratio_qk,
                             block_table_ptr,
                             batch_stride_block_table,
                             page_block_size,
                             is_gappy,
                             scale_s,
                             scale_p,
                             scale_o,
                             logits_soft_cap,
                             stride_q,
                             stride_k,
                             stride_v,
                             stride_bias,
                             stride_o,
                             nhead_stride_q,
                             nhead_stride_k,
                             nhead_stride_v,
                             nhead_stride_bias,
                             nhead_stride_lse,
                             nhead_stride_o,
                             batch_stride_k,
                             batch_stride_v,
                             window_size_left,
                             window_size_right,
                             mask_type,
                             min_seqlen_q);
    }
705 
    /// Debug helper: dump the kernel arguments to stdout.
    ///
    /// The body runs inside a function-local `static` lambda initializer, so it
    /// prints at most once per process regardless of how often it is called (and
    /// only for the first Kargs it ever sees). NOTE(review): it dereferences
    /// kargs.seqlen_k_ptr and kargs.block_table_ptr when present, so those
    /// pointers must be readable from the host when this is used.
    CK_TILE_HOST static void PrintParameters(const Kargs& kargs, int num_batches)
    {
        static bool dummy = [&]() {
            std::cout << std::endl;

            // fields common to batch and group mode
            std::cout << " q_ptr: " << kargs.q_ptr << " k_ptr:" << kargs.k_ptr
                      << " v_ptr: " << kargs.v_ptr << " o_ptr:" << kargs.o_ptr
                      << " hdim_q: " << kargs.hdim_q << " hdim_v: " << kargs.hdim_v
                      << " num_head_q:" << kargs.num_head_q
                      << " nhead_ratio_qk: " << kargs.nhead_ratio_qk << " scale_s:" << kargs.scale_s
                      << " stride_q:" << kargs.stride_q << " stride_k:" << kargs.stride_k
                      << " stride_v:" << kargs.stride_v << " stride_o:" << kargs.stride_o
                      << " nhead_stride_q: " << kargs.nhead_stride_q
                      << " nhead_stride_k: " << kargs.nhead_stride_k
                      << " nhead_stride_v:" << kargs.nhead_stride_v
                      << " nhead_stride_o: " << kargs.nhead_stride_o;
            if constexpr(!kIsGroupMode)
            {
                // batch_stride_q only exists on the batch-mode kargs
                std::cout << " batch_stride_q:" << kargs.batch_stride_q;
            }
            std::cout << " batch_stride_k:" << kargs.batch_stride_k
                      << " batch_stride_v:" << kargs.batch_stride_v;

            if constexpr(kIsGroupMode)
            {
                if constexpr(kSkipMinSeqlenQ)
                {
                    std::cout << " min_seqlen_q: " << kargs.min_seqlen_q;
                }

                std::cout << " seqstart_q_ptr:" << kargs.seqstart_q_ptr
                          << " seqstart_k_ptr: " << kargs.seqstart_k_ptr
                          << " seqlen_k_ptr:" << kargs.seqlen_k_ptr;
                // also dump the per-batch seqlen_k values when the override is set
                if(kargs.seqlen_k_ptr != nullptr)
                {
                    std::cout << "{";
                    for(int i_batch = 0; i_batch < num_batches; i_batch++)
                        std::cout << kargs.seqlen_k_ptr[i_batch] << ",";
                    std::cout << "}";
                }
            }
            if constexpr(kHasMask)
            {
                std::cout << " window_size_left: " << kargs.window_size_left
                          << " window_size_right:" << kargs.window_size_right
                          << " mask_type: " << static_cast<int>(kargs.mask_type);
            }

            if constexpr(kIsPagedKV)
            {
                std::cout << " block_table_ptr: " << kargs.block_table_ptr
                          << " batch_stride_block_table:" << kargs.batch_stride_block_table
                          << " page_block_size: " << kargs.page_block_size;

                // dump the whole block table, one bracketed row per batch
                std::cout << "table value: [";
                for(int b = 0; b < num_batches; b++)
                {
                    std::cout << "[ ";
                    for(int i = 0; i < kargs.batch_stride_block_table; i++)
                    {
                        std::cout << kargs.block_table_ptr[b * kargs.batch_stride_block_table + i]
                                  << ",";
                    }
                    std::cout << " ]";
                }
                std::cout << " ]";
            }
            std::cout << std::endl;
            return true;
        }();
        (void)dummy;
    }
778  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
779  ck_tile::index_t nhead_,
780  ck_tile::index_t seqlen_q_,
781  ck_tile::index_t hdim_v_,
782  bool has_padded_seqlen_k)
783  {
784  // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
785  if(has_padded_seqlen_k)
786  {
787  // TODO: this may need tuning
788  return dim3(nhead_,
789  batch_size_,
790  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
791  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
792  }
793  else
794  {
795  // TODO: this may need tuning
796  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
797  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
798  nhead_,
799  batch_size_);
800  }
801  }
802 
803  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
804  {
805  bool has_padded_seqlen_k = false;
806 
807  if constexpr(kIsGroupMode)
808  has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
809 
810  if(has_padded_seqlen_k)
811  {
812  // const index_t num_tile_m0 = seqlen_q / kM0;
813  const index_t num_tile_n1 =
814  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
815 
816  const index_t i_block = blockIdx.z;
817  const index_t i_nhead = blockIdx.x;
818  const index_t i_batch = blockIdx.y;
819 
820  const auto f = [](index_t dividend, index_t divisor) {
821  index_t quotient = dividend / divisor;
822  index_t modulus = dividend - quotient * divisor;
823  return ck_tile::make_tuple(quotient, modulus);
824  };
825 
826  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
827 
828  if constexpr(kHasMask)
829  {
830  // assume that num_tile_n1 is always 1
831  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
832  }
833  else
834  {
835  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
836  }
837  }
838  else
839  {
840  // const index_t num_tile_m0 = seqlen_q / kM0;
841  const index_t num_tile_n1 =
842  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
843 
844  const index_t i_block = blockIdx.x;
845  const index_t i_nhead = blockIdx.y;
846  const index_t i_batch = blockIdx.z;
847 
848  const auto f = [](index_t dividend, index_t divisor) {
849  index_t quotient = dividend / divisor;
850  index_t modulus = dividend - quotient * divisor;
851  return ck_tile::make_tuple(quotient, modulus);
852  };
853 
854  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
855 
856  if constexpr(kHasMask)
857  {
858  // assume that num_tile_n1 is always 1
859  return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
860  }
861  else
862  {
863  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
864  }
865  }
866  }
867 
868  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
869 
871  {
872  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
873  }
874 
875  CK_TILE_DEVICE void operator()(Kargs kargs) const
876  {
877  // allocate LDS
878  __shared__ char smem_ptr[GetSmemSize()];
879 
880  // divide problem
881  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
882 
883  const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
884  const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
885 
886  long_index_t batch_offset_q = 0;
887  long_index_t batch_offset_k = 0;
888  long_index_t batch_offset_v = 0;
889  long_index_t batch_offset_bias = 0;
890  long_index_t batch_offset_lse = 0;
891  long_index_t batch_offset_o = 0;
892  index_t kv_l2p_offset = 0;
893 
894  if constexpr(kIsGroupMode)
895  {
896  // get starting offset for each batch
897  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
898  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
899 
900  batch_offset_q = query_start * kargs.stride_q;
901  batch_offset_k = key_start * kargs.stride_k;
902  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
903  {
904  batch_offset_v = key_start * kargs.stride_v;
905  }
906  else
907  {
908  batch_offset_v = key_start;
909  }
911  {
912  batch_offset_bias = query_start * kargs.stride_bias;
913  }
914  if constexpr(kStoreLSE)
915  {
916  batch_offset_lse = query_start;
917  }
918 
919  batch_offset_o = query_start * kargs.stride_o;
920 
921  // get real # queries & # keys under group mode
922  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
923  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
924 
925  if constexpr(kSkipMinSeqlenQ)
926  {
927  if(kargs.seqlen_q <= kargs.min_seqlen_q)
928  {
929  return;
930  }
931  }
932 
933  // # of required blocks is different in each groups, terminate unnecessary blocks
934  // earlier
935  if(kargs.seqlen_q <= i_m0)
936  {
937  return;
938  }
939 
940  if(kargs.seqlen_k_ptr != nullptr)
941  {
942  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
943  }
944  else
945  {
946  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
947  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
948  }
949 
950  if constexpr(kIsPagedKV)
951  {
952  if(kargs.is_gappy)
953  {
954  // seqstart_k_ptr has different meaning in this case
955  kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
956  }
957  }
958  }
959  else
960  {
961  const index_t i_cache_batch = [&, i_batch_ = i_batch] {
962  if constexpr(kIsPagedKV)
963  {
964  return i_batch_;
965  }
966  else
967  {
968  return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
969  : i_batch_);
970  }
971  }();
972 
973  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
974  batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
975  batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
977  {
978  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
979  }
980  if constexpr(kStoreLSE)
981  {
982  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
983  }
984 
985  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
986 
987  if(kargs.seqlen_k_ptr != nullptr)
988  {
989  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
990  }
991  }
992 
993  // for simplicity, batch stride we just modify the pointer
994  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
995  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
996  batch_offset_q;
997  const KDataType* k_ptr =
998  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
999  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1000  batch_offset_k;
1001  const VDataType* v_ptr =
1002  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1003  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1004  batch_offset_v;
1005  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1006  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1007  batch_offset_o;
1008 
1009  // Q/K/V DRAM and DRAM window
1010  const auto q_dram = [&]() {
1011  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1012  q_ptr,
1013  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1014  make_tuple(kargs.stride_q, 1),
1016  number<1>{});
1017  if constexpr(FmhaPipeline::kQLoadOnce)
1018  {
1019  return pad_tensor_view(
1020  q_dram_naive,
1023  }
1024  else
1025  {
1026  return pad_tensor_view(
1027  q_dram_naive,
1030  }
1031  }();
1032 
1033  const auto make_k_dram = [&](const KDataType* data, index_t height) {
1034  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1035  data, // will update this pointer if using paged-kvcache
1036  make_tuple(height, kargs.hdim_q),
1037  make_tuple(kargs.stride_k, 1),
1039  number<1>{});
1040 
1041  return pad_tensor_view(
1042  k_dram_naive,
1045  };
1046  const auto k_dram = [&]() {
1047  if constexpr(kIsPagedKV)
1048  {
1049  return make_k_dram(nullptr, kargs.page_block_size);
1050  }
1051  else
1052  {
1053  return make_k_dram(k_ptr, kargs.seqlen_k);
1054  }
1055  }();
1056 
1057  const auto make_v_dram = [&](const VDataType* data, index_t length) {
1058  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1059  {
1060  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1061  data, // will update this pointer if using paged-kvcache
1062  make_tuple(length, kargs.hdim_v),
1063  make_tuple(kargs.stride_v, 1),
1065  number<1>{});
1066 
1067  const auto v_dram_transposed =
1068  transform_tensor_view(v_dram_naive,
1070  make_pass_through_transform(length)),
1073 
1074  return pad_tensor_view(
1075  v_dram_transposed,
1078  }
1079  else
1080  {
1081  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1082  data, // will update this pointer if using paged-kvcache
1083  make_tuple(kargs.hdim_v, length),
1084  make_tuple(kargs.stride_v, 1),
1086  number<1>{});
1087 
1088  return pad_tensor_view(
1089  v_dram_naive,
1092  }
1093  };
1094  const auto v_dram = [&]() {
1095  if constexpr(kIsPagedKV)
1096  {
1097  return make_v_dram(nullptr, kargs.page_block_size);
1098  }
1099  else
1100  {
1101  return make_v_dram(v_ptr, kargs.seqlen_k);
1102  }
1103  }();
1104 
1105  auto q_dram_window = make_tile_window(
1106  q_dram,
1107  [&]() {
1108  if constexpr(FmhaPipeline::kQLoadOnce)
1111  else
1113  }(),
1114  {i_m0, 0});
1115 
1116  auto k_page_block_navigator =
1117  [&, i_batch_ = i_batch, i_nhead_ = i_nhead / kargs.nhead_ratio_qk]() {
1118  if constexpr(kIsPagedKV)
1119  {
1120  const auto* block_indices =
1121  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
1122  i_batch_ * kargs.batch_stride_block_table;
1123  const index_t num_blocks =
1124  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
1125 
1126  const long_index_t fixed_offset =
1127  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_k;
1128 
1129  return make_page_block_navigator<const KDataType, 0>(
1130  kargs.k_ptr,
1131  kargs.batch_stride_k, // kcache page-block stride/size
1132  fixed_offset,
1133  block_indices,
1134  num_blocks,
1135  kargs.page_block_size,
1136  k_dram,
1137  make_k_dram(nullptr,
1138  (kv_l2p_offset + kargs.seqlen_k) -
1139  (num_blocks - 1) * kargs.page_block_size));
1140  }
1141  else
1142  {
1143  return make_page_block_navigator(k_dram);
1144  }
1145  }();
1146 
1147  auto v_page_block_navigator =
1148  [&, i_batch_ = i_batch, i_nhead_ = i_nhead / kargs.nhead_ratio_qk]() {
1149  if constexpr(kIsPagedKV)
1150  {
1151  const auto* block_indices =
1152  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
1153  i_batch_ * kargs.batch_stride_block_table;
1154  const index_t num_blocks =
1155  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
1156 
1157  const long_index_t fixed_offset =
1158  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_v;
1159 
1160  return make_page_block_navigator<const VDataType, 1>(
1161  kargs.v_ptr,
1162  kargs.batch_stride_v, // vcache page-block stride/size
1163  fixed_offset,
1164  block_indices,
1165  num_blocks,
1166  kargs.page_block_size,
1167  v_dram,
1168  make_v_dram(nullptr,
1169  (kv_l2p_offset + kargs.seqlen_k) -
1170  (num_blocks - 1) * kargs.page_block_size));
1171  }
1172  else
1173  {
1174  return make_page_block_navigator(v_dram);
1175  }
1176  }();
1177 
1178  auto k_dram_window_lengths =
1180  auto v_dram_window_lengths =
1182 
1185  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1186  constexpr auto bias_dram_window_lengths =
1189  {
1190  const BiasDataType* bias_ptr =
1191  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1192  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1193  batch_offset_bias;
1194 
1195  const auto bias_dram = [&]() {
1196  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1197  bias_ptr,
1198  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1199  make_tuple(kargs.stride_bias, 1),
1201  number<1>{});
1202 
1203  return pad_tensor_view(bias_dram_naive,
1204  bias_dram_window_lengths,
1206  }();
1207 
1208  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1209  }
1210  else
1211  {
1212  return make_null_tile_window(bias_dram_window_lengths);
1213  }
1214  }();
1215 
1216  // lse
1217  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1218  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1219  if constexpr(kStoreLSE)
1220  {
1221  LSEDataType* lse_ptr =
1222  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1223  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
1224 
1225  const auto lse_dram = [&]() {
1226  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1227  lse_ptr,
1228  make_tuple(kargs.seqlen_q),
1229  make_tuple(1),
1230  number<1>{},
1231  number<1>{});
1232 
1233  return pad_tensor_view(
1234  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1235  }();
1236 
1237  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1238  }
1239  else
1240  {
1241  return make_null_tile_window(lse_dram_window_lengths);
1242  }
1243  }();
1244 
1245  FmhaMask mask = [&]() {
1246  if constexpr(kHasMask)
1247  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1248  kargs.window_size_left,
1249  kargs.window_size_right,
1250  kargs.seqlen_q,
1251  kargs.seqlen_k,
1253  else
1254  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1255  }();
1256 
1257  // WA i_batch capture structure binding before c++20
1258  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1259  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1260  {
1261  // data loading, shared by entire wg
1262  // TODO: how to use s_read?
1263  SaccDataType slope =
1264  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1265  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1266 #if CK_TILE_FMHA_FWD_FAST_EXP2
1267  slope *= ck_tile::log2e_v<>;
1268 #endif
1269  if constexpr(kHasMask)
1270  {
1271  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1272  kargs.window_size_left,
1273  kargs.window_size_right,
1274  kargs.seqlen_q,
1275  kargs.seqlen_k,
1276  kargs.mask_type);
1277  }
1278  else
1279  {
1281  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1282  }
1283  }
1284  else
1285  {
1287  }
1288  }();
1289 
1290  AttentionVariant variant;
1291  const auto variant_params = [&] {
1292  if constexpr(kHasLogitsSoftCap)
1293  {
1295  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1296  }
1297  else
1298  {
1299  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1300  }
1301  }();
1302 
1303  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1304 
1305  auto o_acc_tile = [&]() {
1306  if constexpr(kDoFp8StaticQuant)
1307  {
1308  return FmhaPipeline{}(
1309  q_dram_window,
1310  identity{}, // q_element_func
1311  k_dram_window_lengths,
1312  k_page_block_navigator,
1313  identity{}, // k_element_func
1314  v_dram_window_lengths,
1315  v_page_block_navigator,
1316  identity{}, // v_element_func
1317  bias_dram_window,
1318  identity{}, // bias_element_func
1319  lse_dram_window,
1320  identity{}, // lse_element_func
1321  identity{}, // s_acc_element_func
1322  scales{kargs.scale_p}, // p_compute_element_func
1323  composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
1324  mask,
1325  position_encoding,
1326  kargs.scale_s,
1327  variant,
1328  variant_params,
1329  block_indices,
1330  kv_l2p_offset,
1331  smem_ptr);
1332  }
1333  else
1334  {
1335  return FmhaPipeline{}(q_dram_window,
1336  k_dram_window_lengths,
1337  k_page_block_navigator,
1338  v_dram_window_lengths,
1339  v_page_block_navigator,
1340  bias_dram_window,
1341  lse_dram_window,
1342  mask,
1343  position_encoding,
1344  kargs.scale_s,
1345  variant,
1346  variant_params,
1347  block_indices,
1348  kv_l2p_offset,
1349  smem_ptr);
1350  }
1351  }();
1352 
1353  // O DRAM and O DRAM window
1354  auto o_dram = [&]() {
1355  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1356  o_ptr,
1357  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1358  make_tuple(kargs.stride_o, 1),
1360  number<1>{});
1361  return pad_tensor_view(
1362  o_dram_naive,
1365  }();
1366 
1367  auto o_dram_window =
1368  make_tile_window(o_dram,
1370  {i_m0, i_n1});
1371 
1372  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1373  }
1374 };
1375 
1376 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition: page_block_navigator.hpp:333
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:511
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create macro to replace 'host device' and nothing more.
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_position_encoding.hpp:137
Definition: fmha_fwd_pagedkv_kernel.hpp:281
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_pagedkv_kernel.hpp:283
ck_tile::index_t batch_idx
Definition: fmha_fwd_pagedkv_kernel.hpp:282
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_pagedkv_kernel.hpp:284
Definition: fmha_fwd_pagedkv_kernel.hpp:227
const int32_t * cache_batch_idx
Definition: fmha_fwd_pagedkv_kernel.hpp:228
Definition: fmha_fwd_pagedkv_kernel.hpp:215
ck_tile::index_t batch_stride_block_table
Definition: fmha_fwd_pagedkv_kernel.hpp:217
const int32_t * block_table_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:216
ck_tile::index_t page_block_size
Definition: fmha_fwd_pagedkv_kernel.hpp:218
Definition: fmha_fwd_pagedkv_kernel.hpp:183
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_pagedkv_kernel.hpp:186
const void * alibi_slope_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:185
Definition: fmha_fwd_pagedkv_kernel.hpp:178
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_pagedkv_kernel.hpp:179
Definition: fmha_fwd_pagedkv_kernel.hpp:243
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_pagedkv_kernel.hpp:247
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_pagedkv_kernel.hpp:246
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:244
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_pagedkv_kernel.hpp:251
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_pagedkv_kernel.hpp:249
Definition: fmha_fwd_pagedkv_kernel.hpp:171
ck_tile::index_t stride_bias
Definition: fmha_fwd_pagedkv_kernel.hpp:173
const void * bias_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:172
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_pagedkv_kernel.hpp:174
Definition: fmha_fwd_pagedkv_kernel.hpp:120
ck_tile::index_t hdim_v
Definition: fmha_fwd_pagedkv_kernel.hpp:129
ck_tile::index_t seqlen_q
Definition: fmha_fwd_pagedkv_kernel.hpp:126
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_pagedkv_kernel.hpp:143
ck_tile::index_t stride_o
Definition: fmha_fwd_pagedkv_kernel.hpp:140
const void * k_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:122
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_pagedkv_kernel.hpp:142
ck_tile::index_t stride_v
Definition: fmha_fwd_pagedkv_kernel.hpp:139
float scale_s
Definition: fmha_fwd_pagedkv_kernel.hpp:135
ck_tile::index_t stride_k
Definition: fmha_fwd_pagedkv_kernel.hpp:138
const void * v_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:123
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_pagedkv_kernel.hpp:144
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_pagedkv_kernel.hpp:145
ck_tile::index_t hdim_q
Definition: fmha_fwd_pagedkv_kernel.hpp:128
ck_tile::index_t seqlen_k
Definition: fmha_fwd_pagedkv_kernel.hpp:127
const void * q_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:121
ck_tile::index_t num_head_q
Definition: fmha_fwd_pagedkv_kernel.hpp:131
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_pagedkv_kernel.hpp:134
void * o_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:124
ck_tile::index_t stride_q
Definition: fmha_fwd_pagedkv_kernel.hpp:137
Definition: fmha_fwd_pagedkv_kernel.hpp:203
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_pagedkv_kernel.hpp:206
void * lse_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:204
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_pagedkv_kernel.hpp:205
Definition: fmha_fwd_pagedkv_kernel.hpp:113
Definition: fmha_fwd_pagedkv_kernel.hpp:197
float scale_p
Definition: fmha_fwd_pagedkv_kernel.hpp:198
float scale_o
Definition: fmha_fwd_pagedkv_kernel.hpp:199
Definition: fmha_fwd_pagedkv_kernel.hpp:267
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:270
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:269
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_pagedkv_kernel.hpp:272
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_pagedkv_kernel.hpp:274
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:268
Definition: fmha_fwd_pagedkv_kernel.hpp:149
float logits_soft_cap
Definition: fmha_fwd_pagedkv_kernel.hpp:166
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_pagedkv_kernel.hpp:152
float logits_soft_cap_rcp
Definition: fmha_fwd_pagedkv_kernel.hpp:167
Definition: fmha_fwd_pagedkv_kernel.hpp:190
ck_tile::index_t window_size_left
Definition: fmha_fwd_pagedkv_kernel.hpp:192
ck_tile::index_t window_size_right
Definition: fmha_fwd_pagedkv_kernel.hpp:192
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_pagedkv_kernel.hpp:193
Definition: fmha_fwd_pagedkv_kernel.hpp:210
ck_tile::index_t min_seqlen_q
Definition: fmha_fwd_pagedkv_kernel.hpp:211
Definition: fmha_fwd_pagedkv_kernel.hpp:222
bool is_gappy
Definition: fmha_fwd_pagedkv_kernel.hpp:223
Definition: fmha_fwd_pagedkv_kernel.hpp:66
Definition: fmha_fwd_pagedkv_kernel.hpp:28
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
Definition: fmha_fwd_pagedkv_kernel.hpp:289
static constexpr bool kIsGroupMode
Definition: fmha_fwd_pagedkv_kernel.hpp:47
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_pagedkv_kernel.hpp:52
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_pagedkv_kernel.hpp:30
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_pagedkv_kernel.hpp:32
static constexpr bool kStoreLSE
Definition: fmha_fwd_pagedkv_kernel.hpp:54
static CK_TILE_HOST std::string GetName()
Definition: fmha_fwd_pagedkv_kernel.hpp:74
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_pagedkv_kernel.hpp:49
static constexpr bool kIsPagedKV
Definition: fmha_fwd_pagedkv_kernel.hpp:57
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q)
Definition: fmha_fwd_pagedkv_kernel.hpp:627
static constexpr bool kSkipMinSeqlenQ
Definition: fmha_fwd_pagedkv_kernel.hpp:56
static CK_TILE_HOST void PrintParameters(const Kargs &kargs, int num_batches)
Definition: fmha_fwd_pagedkv_kernel.hpp:706
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
Definition: fmha_fwd_pagedkv_kernel.hpp:416
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_pagedkv_kernel.hpp:31
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_pagedkv_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:41
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:39
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_pagedkv_kernel.hpp:803
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_pagedkv_kernel.hpp:35
static constexpr bool kUseAsyncCopy
Definition: fmha_fwd_pagedkv_kernel.hpp:63
static constexpr bool kHasMask
Definition: fmha_fwd_pagedkv_kernel.hpp:61
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k)
Definition: fmha_fwd_pagedkv_kernel.hpp:778
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_pagedkv_kernel.hpp:29
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_pagedkv_kernel.hpp:42
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_fwd_pagedkv_kernel.hpp:868
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:40
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_pagedkv_kernel.hpp:50
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_pagedkv_kernel.hpp:45
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_pagedkv_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:43
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_pagedkv_kernel.hpp:60
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_pagedkv_kernel.hpp:870
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_pagedkv_kernel.hpp:875
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_pagedkv_kernel.hpp:59
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:37
static constexpr bool kDoFp8StaticQuant
Definition: fmha_fwd_pagedkv_kernel.hpp:55
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_pagedkv_kernel.hpp:278
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q)
Definition: fmha_fwd_pagedkv_kernel.hpp:503
static constexpr auto BiasEnum
Definition: fmha_fwd_pagedkv_kernel.hpp:53
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: functional.hpp:86
Definition: unary_element_function.hpp:56
Definition: math.hpp:28
Definition: sequence.hpp:49