/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp Source File
fmha_fwd_pagedkv_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
21 
22 namespace ck_tile {
23 
24 // TODO: This class is a variant of the existing FmhaFwdSplitKVKernel pipeline.
25 // Refactoring to extract shared logic is recommended as future work.
26 template <typename FmhaPipeline_, typename EpiloguePipeline_>
28 {
31  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
32  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
33 
34  static_assert(kBlockPerCu > 0);
35  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
36 
44 
46 
47  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
48  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
49  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
50  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
51  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
52  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
53  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
54  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
55  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
56  static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
57  static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
58  static constexpr bool kHasSink = FmhaPipeline::kHasSink;
59 
62  static constexpr bool kHasMask = FmhaMask::IsMasking;
63 
64  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
65 
    // clang-format off
    // Compile-time map from an element data type to the short type-name token
    // used when composing the kernel name in GetName() (kept in sync with the
    // generate.py codegen naming scheme).
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on
74 
    /// @brief Compose the canonical kernel-name string for this instantiation.
    ///
    /// Encodes the head dimension, data type, mode (batch/group), tile shape,
    /// warp partitioning, occupancy, pipeline name and every feature flag
    /// (padding, logits soft-cap, bias kind, mask, LSE, skip-min-seqlen-q,
    /// fp8 static quant, paged-kv, sink) into a single string. The format is
    /// shared with the generate.py codegen script — do not alter tokens here
    /// without updating that script.
    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        using bfs = typename FmhaPipeline::BlockFmhaShape;
        using g0br = typename bfs::Gemm0BlockWarps;
        using g1br = typename bfs::Gemm1BlockWarps;
        using g0wt = typename bfs::Gemm0WarpTile;
        using g1wt = typename bfs::Gemm1WarpTile;
        #define _SS_ std::string
        #define _TS_ std::to_string
        // pad-dimension tag: "p" + one letter group per padded dim, or empty if none
        auto pn = [&] () {
        std::string n;
        if (kPadSeqLenQ) n += "s";
        if (kPadSeqLenK) n += "sk";
        if (kPadHeadDimQ) n += "d";
        if (kPadHeadDimV) n += "dv";
        return n.empty() ? n : std::string("p") + n; }();
        return
        _SS_("fmha_fwd_pagedkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
        "_" + (kIsGroupMode ? "group" : "batch") + "_"
        "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
        _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
        "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
        "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
        "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
        "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
        (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
        "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
        (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
        (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
        #undef _SS_
        #undef _TS_
        // clang-format on
    }
110 
111  template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
112  // arg
114  {
115  };
116 
117  // kargs use aggregate initializer, so no constructor will provided
118  // use inheritance to minimize karg size
119  // user need to use MakeKargs() function to create kargs.
121  {
122  const void* q_ptr;
123  const void* k_ptr;
124  const void* v_ptr;
125  void* o_ptr;
126 
131 
133  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
134  // if this param is larger than 1, indicate MQA/GQA case
136  float scale_s;
137 
142 
147  };
148 
150  {
152 
153  void init_logits_soft_cap(float logits_soft_cap_)
154  {
155  if(0 < logits_soft_cap_)
156  {
157  logits_soft_cap = logits_soft_cap_;
159  }
160  else
161  {
162  logits_soft_cap = 0.f;
163  logits_soft_cap_rcp = 0.f;
164  }
165  }
166 
169  };
170 
172  {
173  const void* bias_ptr = nullptr;
176  };
177 
179  {
181  };
182 
184  {
185  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
186  const void* alibi_slope_ptr;
187  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
188  };
189 
191  {
192  // ck_tile::index_t window_size_left, window_size_right;
195  };
196 
198  {
199  float scale_p;
200  float scale_o;
201  };
202 
204  {
205  void* lse_ptr = nullptr;
208  };
209 
211  {
213  };
214 
216  {
220  };
221 
223  {
224  bool is_gappy = false;
225  };
226 
228  {
230  };
231 
234  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
235  FmhaFwdBatchModeBiasKargs,
236  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
237  FmhaFwdAlibiKargs,
238  FmhaFwdEmptyKargs<0>>>,
239  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
240  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
241  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
242  std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>,
243  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>
244  {
246 
248  ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
249  // single kcache page-block
250  ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
251  // single vcache page-block
253  };
254 
257  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
258  FmhaFwdCommonBiasKargs,
259  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
260  FmhaFwdAlibiKargs,
261  FmhaFwdEmptyKargs<0>>>,
262  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
263  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
264  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
265  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>,
266  std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, FmhaFwdEmptyKargs<5>>,
267  std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
268  {
272 
273  ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
274  // for single kcache page-block
275  ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
276  // for single vcache page-block
277  };
278 
279  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
280 
282  {
286  };
287 
288  template <bool Cond = !kIsGroupMode>
289  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
290  MakeKargsImpl(const void* q_ptr,
291  const void* k_ptr,
292  const void* v_ptr,
293  const void* bias_ptr,
294  void* lse_ptr,
295  void* o_ptr,
296  ck_tile::index_t seqlen_q,
297  ck_tile::index_t seqlen_k,
298  const void* seqlen_k_ptr, // only used for (paged-) kvcache
299  ck_tile::index_t hdim_q,
300  ck_tile::index_t hdim_v,
301  ck_tile::index_t num_head_q,
302  ck_tile::index_t nhead_ratio_qk,
303  const void* block_table_ptr,
304  ck_tile::index_t batch_stride_block_table,
305  ck_tile::index_t page_block_size,
306  const void* cache_batch_idx,
307  float scale_s,
308  float scale_p,
309  float scale_o,
310  float logits_soft_cap,
311  ck_tile::index_t stride_q,
312  ck_tile::index_t stride_k,
313  ck_tile::index_t stride_v,
314  ck_tile::index_t stride_bias,
315  ck_tile::index_t stride_o,
316  ck_tile::index_t nhead_stride_q,
317  ck_tile::index_t nhead_stride_k,
318  ck_tile::index_t nhead_stride_v,
319  ck_tile::index_t nhead_stride_bias,
320  ck_tile::index_t nhead_stride_lse,
321  ck_tile::index_t nhead_stride_o,
322  ck_tile::index_t batch_stride_q,
323  ck_tile::index_t batch_stride_k,
324  ck_tile::index_t batch_stride_v,
325  ck_tile::index_t batch_stride_bias,
326  ck_tile::index_t batch_stride_lse,
327  ck_tile::index_t batch_stride_o,
328  ck_tile::index_t window_size_left,
329  ck_tile::index_t window_size_right,
330  ck_tile::index_t sink_size,
331  ck_tile::index_t mask_type)
332  {
333  Kargs kargs{{q_ptr,
334  k_ptr,
335  v_ptr,
336  o_ptr,
337  seqlen_q,
338  seqlen_k,
339  hdim_q,
340  hdim_v,
341  num_head_q,
342  nhead_ratio_qk,
343 #if CK_TILE_FMHA_FWD_FAST_EXP2
344  static_cast<float>(scale_s * ck_tile::log2e_v<>),
345 #else
346  scale_s,
347 #endif
348  stride_q,
349  stride_k,
350  stride_v,
351  stride_o,
352  nhead_stride_q,
353  nhead_stride_k,
354  nhead_stride_v,
355  nhead_stride_o}, // args for common karg
356  {}, // placeholder for bias
357  {}, // placeholder for mask
358  {}, // placeholder for lse
359  {}, // placeholder for fp8_static_quant args
360  {}, // placeholder for pagedkv
361  {}, // placeholder for logits_soft_cap
362  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
363  batch_stride_q,
364  batch_stride_k,
365  batch_stride_v,
366  batch_stride_o};
367 
369  {
370  kargs.bias_ptr = bias_ptr;
371  kargs.stride_bias = stride_bias;
372  kargs.nhead_stride_bias = nhead_stride_bias;
373  kargs.batch_stride_bias = batch_stride_bias;
374  }
375  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
376  {
377  kargs.alibi_slope_ptr = bias_ptr;
378  kargs.alibi_slope_stride = stride_bias;
379  }
380  if constexpr(kHasMask)
381  {
382  kargs.window_size_left = window_size_left;
383  kargs.window_size_right = window_size_right;
384  kargs.sink_size = sink_size;
385  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
386  }
387  if constexpr(kStoreLSE)
388  {
389  kargs.lse_ptr = lse_ptr;
390  kargs.nhead_stride_lse = nhead_stride_lse;
391  kargs.batch_stride_lse = batch_stride_lse;
392  }
393  if constexpr(kDoFp8StaticQuant)
394  {
395  kargs.scale_p = scale_p;
396  kargs.scale_o = scale_o;
397  }
398  if constexpr(kIsPagedKV)
399  {
400  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
401  kargs.batch_stride_block_table = batch_stride_block_table;
402  kargs.page_block_size = page_block_size;
403  }
404  else
405  {
406  kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
407  }
408  if constexpr(kHasLogitsSoftCap)
409  {
410  kargs.init_logits_soft_cap(logits_soft_cap);
411  }
412 
413  return kargs;
414  }
415 
    // std::variant<> can't take in a list initializer, overload for backward compatibility
    /// @brief Batch-mode kernel-argument factory (thin forwarder).
    ///
    /// Forwards every argument, in the same order, to the batch-mode
    /// MakeKargsImpl() overload; kept as a separate entry point for backward
    /// compatibility (see note above). Enabled only when !kIsGroupMode.
    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              const void* seqlen_k_ptr, // only used for (paged-) kvcache
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              const void* block_table_ptr,
              ck_tile::index_t batch_stride_block_table,
              ck_tile::index_t page_block_size,
              const void* cache_batch_idx,
              float scale_s,
              float scale_p,
              float scale_o,
              float logits_soft_cap,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_q,
              ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v,
              ck_tile::index_t batch_stride_bias,
              ck_tile::index_t batch_stride_lse,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t sink_size,
              ck_tile::index_t mask_type)
    {
        return MakeKargsImpl(q_ptr,
                             k_ptr,
                             v_ptr,
                             bias_ptr,
                             lse_ptr,
                             o_ptr,
                             seqlen_q,
                             seqlen_k,
                             seqlen_k_ptr,
                             hdim_q,
                             hdim_v,
                             num_head_q,
                             nhead_ratio_qk,
                             block_table_ptr,
                             batch_stride_block_table,
                             page_block_size,
                             cache_batch_idx,
                             scale_s,
                             scale_p,
                             scale_o,
                             logits_soft_cap,
                             stride_q,
                             stride_k,
                             stride_v,
                             stride_bias,
                             stride_o,
                             nhead_stride_q,
                             nhead_stride_k,
                             nhead_stride_v,
                             nhead_stride_bias,
                             nhead_stride_lse,
                             nhead_stride_o,
                             batch_stride_q,
                             batch_stride_k,
                             batch_stride_v,
                             batch_stride_bias,
                             batch_stride_lse,
                             batch_stride_o,
                             window_size_left,
                             window_size_right,
                             sink_size,
                             mask_type);
    }
505 
506  template <bool Cond = kIsGroupMode>
507  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
508  MakeKargsImpl(const void* q_ptr,
509  const void* k_ptr,
510  const void* v_ptr,
511  const void* bias_ptr,
512  void* lse_ptr,
513  void* o_ptr,
514  const void* seqstart_q_ptr,
515  const void* seqstart_k_ptr,
516  const void* seqlen_k_ptr,
517  ck_tile::index_t hdim_q,
518  ck_tile::index_t hdim_v,
519  ck_tile::index_t num_head_q,
520  ck_tile::index_t nhead_ratio_qk,
521  const void* block_table_ptr,
522  ck_tile::index_t batch_stride_block_table,
523  ck_tile::index_t page_block_size,
524  bool is_gappy,
525  float scale_s,
526  float scale_p,
527  float scale_o,
528  float logits_soft_cap,
529  ck_tile::index_t stride_q,
530  ck_tile::index_t stride_k,
531  ck_tile::index_t stride_v,
532  ck_tile::index_t stride_bias,
533  ck_tile::index_t stride_o,
534  ck_tile::index_t nhead_stride_q,
535  ck_tile::index_t nhead_stride_k,
536  ck_tile::index_t nhead_stride_v,
537  ck_tile::index_t nhead_stride_bias,
538  ck_tile::index_t nhead_stride_lse,
539  ck_tile::index_t nhead_stride_o,
540  ck_tile::index_t batch_stride_k, // only used for paged-kvcache
541  ck_tile::index_t batch_stride_v, // only used for paged-kvcache
542  ck_tile::index_t window_size_left,
543  ck_tile::index_t window_size_right,
544  ck_tile::index_t sink_size,
545  ck_tile::index_t mask_type,
546  ck_tile::index_t min_seqlen_q)
547  {
548  Kargs kargs{{q_ptr,
549  k_ptr,
550  v_ptr,
551  o_ptr,
552  -1, // seqlen will be updated by another pointer
553  -1, //
554  hdim_q,
555  hdim_v,
556  num_head_q,
557  nhead_ratio_qk,
558 #if CK_TILE_FMHA_FWD_FAST_EXP2
559  static_cast<float>(scale_s * ck_tile::log2e_v<>),
560 #else
561  scale_s,
562 #endif
563  stride_q,
564  stride_k,
565  stride_v,
566  stride_o,
567  nhead_stride_q,
568  nhead_stride_k,
569  nhead_stride_v,
570  nhead_stride_o}, // args for common karg
571  {}, // placeholder for bias
572  {}, // placeholder for mask
573  {}, // placeholder for lse
574  {}, // placeholder for fp8_static_quant args
575  {}, // placeholder for logits_soft_cap
576  {}, // placeholder for pagdkv
577  {}, // placeholder for min_seqlen_q
578  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
579  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
580  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
581  batch_stride_k,
582  batch_stride_v};
583 
585  {
586  kargs.bias_ptr = bias_ptr;
587  kargs.stride_bias = stride_bias;
588  kargs.nhead_stride_bias = nhead_stride_bias;
589  }
590  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
591  {
592  kargs.alibi_slope_ptr = bias_ptr;
593  kargs.alibi_slope_stride = stride_bias;
594  }
595  if constexpr(kHasMask)
596  {
597  kargs.window_size_left = window_size_left;
598  kargs.window_size_right = window_size_right;
599  kargs.sink_size = sink_size;
600  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
601  }
602  if constexpr(kStoreLSE)
603  {
604  kargs.lse_ptr = lse_ptr;
605  kargs.nhead_stride_lse = nhead_stride_lse;
606  }
607  if constexpr(kDoFp8StaticQuant)
608  {
609  kargs.scale_p = scale_p;
610  kargs.scale_o = scale_o;
611  }
612  if constexpr(kHasLogitsSoftCap)
613  {
614  kargs.init_logits_soft_cap(logits_soft_cap);
615  }
616  if constexpr(kIsPagedKV)
617  {
618  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
619  kargs.batch_stride_block_table = batch_stride_block_table;
620  kargs.page_block_size = page_block_size;
621  kargs.is_gappy = is_gappy;
622  }
623  if constexpr(kSkipMinSeqlenQ)
624  {
625  kargs.min_seqlen_q = min_seqlen_q;
626  }
627 
628  return kargs;
629  }
630 
    // std::variant<> can't take in a list initializer, overload for backward compatibility
    /// @brief Group-mode kernel-argument factory (thin forwarder).
    ///
    /// Forwards every argument, in the same order, to the group-mode
    /// MakeKargsImpl() overload; kept as a separate entry point for backward
    /// compatibility (see note above). Enabled only when kIsGroupMode.
    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              void* lse_ptr,
              void* o_ptr,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              const void* seqlen_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              const void* block_table_ptr,
              ck_tile::index_t batch_stride_block_table,
              ck_tile::index_t page_block_size,
              bool is_gappy,
              float scale_s,
              float scale_p,
              float scale_o,
              float logits_soft_cap,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_k, // only used for paged-kvcache
              ck_tile::index_t batch_stride_v, // only used for paged-kvcache
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t sink_size,
              ck_tile::index_t mask_type,
              ck_tile::index_t min_seqlen_q)
    {
        return MakeKargsImpl(q_ptr,
                             k_ptr,
                             v_ptr,
                             bias_ptr,
                             lse_ptr,
                             o_ptr,
                             seqstart_q_ptr,
                             seqstart_k_ptr,
                             seqlen_k_ptr,
                             hdim_q,
                             hdim_v,
                             num_head_q,
                             nhead_ratio_qk,
                             block_table_ptr,
                             batch_stride_block_table,
                             page_block_size,
                             is_gappy,
                             scale_s,
                             scale_p,
                             scale_o,
                             logits_soft_cap,
                             stride_q,
                             stride_k,
                             stride_v,
                             stride_bias,
                             stride_o,
                             nhead_stride_q,
                             nhead_stride_k,
                             nhead_stride_v,
                             nhead_stride_bias,
                             nhead_stride_lse,
                             nhead_stride_o,
                             batch_stride_k,
                             batch_stride_v,
                             window_size_left,
                             window_size_right,
                             sink_size,
                             mask_type,
                             min_seqlen_q);
    }
714 
    /// @brief Debug helper: dumps the kernel arguments to stdout exactly once.
    ///
    /// The printing body is a lambda whose result initializes a function-local
    /// static, so only the arguments of the FIRST call are ever printed; all
    /// later calls are no-ops.
    /// @param kargs       fully-populated kernel arguments (batch or group mode)
    /// @param num_batches number of batches; bounds the per-batch array dumps
    CK_TILE_HOST static void PrintParameters(const Kargs& kargs, int num_batches)
    {
        static bool dummy = [&]() {
            std::cout << std::endl;

            std::cout << " q_ptr: " << kargs.q_ptr << " k_ptr:" << kargs.k_ptr
                      << " v_ptr: " << kargs.v_ptr << " o_ptr:" << kargs.o_ptr
                      << " hdim_q: " << kargs.hdim_q << " hdim_v: " << kargs.hdim_v
                      << " num_head_q:" << kargs.num_head_q
                      << " nhead_ratio_qk: " << kargs.nhead_ratio_qk << " scale_s:" << kargs.scale_s
                      << " stride_q:" << kargs.stride_q << " stride_k:" << kargs.stride_k
                      << " stride_v:" << kargs.stride_v << " stride_o:" << kargs.stride_o
                      << " nhead_stride_q: " << kargs.nhead_stride_q
                      << " nhead_stride_k: " << kargs.nhead_stride_k
                      << " nhead_stride_v:" << kargs.nhead_stride_v
                      << " nhead_stride_o: " << kargs.nhead_stride_o;
            if constexpr(!kIsGroupMode)
            {
                // batch mode only: group mode locates q via seqstart offsets instead
                std::cout << " batch_stride_q:" << kargs.batch_stride_q;
            }
            std::cout << " batch_stride_k:" << kargs.batch_stride_k
                      << " batch_stride_v:" << kargs.batch_stride_v;

            if constexpr(kIsGroupMode)
            {
                if constexpr(kSkipMinSeqlenQ)
                {
                    std::cout << " min_seqlen_q: " << kargs.min_seqlen_q;
                }

                std::cout << " seqstart_q_ptr:" << kargs.seqstart_q_ptr
                          << " seqstart_k_ptr: " << kargs.seqstart_k_ptr
                          << " seqlen_k_ptr:" << kargs.seqlen_k_ptr;
                if(kargs.seqlen_k_ptr != nullptr)
                {
                    // per-batch effective key lengths
                    std::cout << "{";
                    for(int i_batch = 0; i_batch < num_batches; i_batch++)
                        std::cout << kargs.seqlen_k_ptr[i_batch] << ",";
                    std::cout << "}";
                }
            }
            if constexpr(kHasMask)
            {
                std::cout << " window_size_left: " << kargs.window_size_left
                          << " window_size_right:" << kargs.window_size_right
                          << " mask_type: " << static_cast<int>(kargs.mask_type);
            }

            if constexpr(kIsPagedKV)
            {
                std::cout << " block_table_ptr: " << kargs.block_table_ptr
                          << " batch_stride_block_table:" << kargs.batch_stride_block_table
                          << " page_block_size: " << kargs.page_block_size;

                // NOTE(review): walks batch_stride_block_table entries per batch,
                // i.e. assumes the stride equals the entry count — confirm with callers.
                std::cout << "table value: [";
                for(int b = 0; b < num_batches; b++)
                {
                    std::cout << "[ ";
                    for(int i = 0; i < kargs.batch_stride_block_table; i++)
                    {
                        std::cout << kargs.block_table_ptr[b * kargs.batch_stride_block_table + i]
                                  << ",";
                    }
                    std::cout << " ]";
                }
                std::cout << " ]";
            }
            std::cout << std::endl;
            return true;
        }();
        (void)dummy;
    }
787  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
788  ck_tile::index_t nhead_,
789  ck_tile::index_t seqlen_q_,
790  ck_tile::index_t hdim_v_,
791  bool has_padded_seqlen_k)
792  {
793  // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
794  if(has_padded_seqlen_k)
795  {
796  // TODO: this may need tuning
797  return dim3(nhead_,
798  batch_size_,
799  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
800  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
801  }
802  else
803  {
804  // TODO: this may need tuning
805  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
806  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
807  nhead_,
808  batch_size_);
809  }
810  }
811 
812  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
813  {
814  bool has_padded_seqlen_k = false;
815 
816  if constexpr(kIsGroupMode)
817  has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
818 
819  if(has_padded_seqlen_k)
820  {
821  // const index_t num_tile_m0 = seqlen_q / kM0;
822  const index_t num_tile_n1 =
823  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
824 
825  const index_t i_block = blockIdx.z;
826  const index_t i_nhead = blockIdx.x;
827  const index_t i_batch = blockIdx.y;
828 
829  const auto f = [](index_t dividend, index_t divisor) {
830  index_t quotient = dividend / divisor;
831  index_t modulus = dividend - quotient * divisor;
832  return ck_tile::make_tuple(quotient, modulus);
833  };
834 
835  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
836 
837  if constexpr(kHasMask)
838  {
839  // assume that num_tile_n1 is always 1
840  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
841  }
842  else
843  {
844  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
845  }
846  }
847  else
848  {
849  // const index_t num_tile_m0 = seqlen_q / kM0;
850  const index_t num_tile_n1 =
851  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
852 
853  const index_t i_block = blockIdx.x;
854  const index_t i_nhead = blockIdx.y;
855  const index_t i_batch = blockIdx.z;
856 
857  const auto f = [](index_t dividend, index_t divisor) {
858  index_t quotient = dividend / divisor;
859  index_t modulus = dividend - quotient * divisor;
860  return ck_tile::make_tuple(quotient, modulus);
861  };
862 
863  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
864 
865  if constexpr(kHasMask)
866  {
867  // assume that num_tile_n1 is always 1
868  return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
869  }
870  else
871  {
872  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
873  }
874  }
875  }
876 
877  CK_TILE_HOST static dim3 BlockSize()
878  {
879  if(is_wave32())
880  {
881  return dim3(kBlockSize / 2);
882  }
883  else
884  {
885  return dim3(kBlockSize);
886  }
887  }
888 
890  {
891  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
892  }
893 
894  CK_TILE_DEVICE void operator()(Kargs kargs) const
895  {
896  // allocate LDS
897  __shared__ char smem_ptr[GetSmemSize()];
898 
899  // divide problem
900  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
901 
902  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
903  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
904 
905  long_index_t batch_offset_q = 0;
906  long_index_t batch_offset_k = 0;
907  long_index_t batch_offset_v = 0;
908  long_index_t batch_offset_bias = 0;
909  long_index_t batch_offset_lse = 0;
910  long_index_t batch_offset_o = 0;
911  index_t kv_l2p_offset = 0;
912 
913  if constexpr(kIsGroupMode)
914  {
915  // get starting offset for each batch
916  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
917  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
918 
919  batch_offset_q = query_start * kargs.stride_q;
920  batch_offset_k = key_start * kargs.stride_k;
921  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
922  {
923  batch_offset_v = key_start * kargs.stride_v;
924  }
925  else
926  {
927  batch_offset_v = key_start;
928  }
930  {
931  batch_offset_bias = query_start * kargs.stride_bias;
932  }
933  if constexpr(kStoreLSE)
934  {
935  batch_offset_lse = query_start;
936  }
937 
938  batch_offset_o = query_start * kargs.stride_o;
939 
940  // get real # queries & # keys under group mode
941  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
942  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
943 
944  if constexpr(kSkipMinSeqlenQ)
945  {
946  if(kargs.seqlen_q <= kargs.min_seqlen_q)
947  {
948  return;
949  }
950  }
951 
952  // # of required blocks is different in each groups, terminate unnecessary blocks
953  // earlier
954  if(kargs.seqlen_q <= i_m0)
955  {
956  return;
957  }
958 
959  if(kargs.seqlen_k_ptr != nullptr)
960  {
961  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
962  }
963  else
964  {
965  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
966  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
967  }
968 
969  if constexpr(kIsPagedKV)
970  {
971  if(kargs.is_gappy)
972  {
973  // seqstart_k_ptr has different meaning in this case
974  kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
975  }
976  }
977  }
978  else
979  {
980  const index_t i_cache_batch = [&, i_batch_ = i_batch] {
981  if constexpr(kIsPagedKV)
982  {
983  return i_batch_;
984  }
985  else
986  {
987  return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
988  : i_batch_);
989  }
990  }();
991 
992  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
993  batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
994  batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
996  {
997  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
998  }
999  if constexpr(kStoreLSE)
1000  {
1001  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1002  }
1003 
1004  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1005 
1006  if(kargs.seqlen_k_ptr != nullptr)
1007  {
1008  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1009  }
1010  }
1011 
1012  // for simplicity, batch stride we just modify the pointer
1013  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1014  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1015  batch_offset_q;
1016  const KDataType* k_ptr =
1017  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1018  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1019  batch_offset_k;
1020  const VDataType* v_ptr =
1021  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1022  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1023  batch_offset_v;
1024  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1025  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1026  batch_offset_o;
1027 
1028  // Q/K/V DRAM and DRAM window
1029  const auto q_dram = [&]() {
1030  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1031  q_ptr,
1032  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1033  make_tuple(kargs.stride_q, 1),
1035  number<1>{});
1036  if constexpr(FmhaPipeline::kQLoadOnce)
1037  {
1038  return pad_tensor_view(
1039  q_dram_naive,
1042  }
1043  else
1044  {
1045  return pad_tensor_view(
1046  q_dram_naive,
1049  }
1050  }();
1051 
1052  const auto make_k_dram = [&](const KDataType* data, index_t height) {
1053  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1054  data, // will update this pointer if using paged-kvcache
1055  make_tuple(height, kargs.hdim_q),
1056  make_tuple(kargs.stride_k, 1),
1058  number<1>{});
1059 
1060  return pad_tensor_view(
1061  k_dram_naive,
1064  };
1065  const auto k_dram = [&]() {
1066  if constexpr(kIsPagedKV)
1067  {
1068  return make_k_dram(nullptr, kargs.page_block_size);
1069  }
1070  else
1071  {
1072  return make_k_dram(k_ptr, kargs.seqlen_k);
1073  }
1074  }();
1075 
1076  const auto make_v_dram = [&](const VDataType* data, index_t length) {
1077  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1078  {
1079  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1080  data, // will update this pointer if using paged-kvcache
1081  make_tuple(length, kargs.hdim_v),
1082  make_tuple(kargs.stride_v, 1),
1084  number<1>{});
1085 
1086  const auto v_dram_transposed =
1087  transform_tensor_view(v_dram_naive,
1089  make_pass_through_transform(length)),
1092 
1093  return pad_tensor_view(
1094  v_dram_transposed,
1097  }
1098  else
1099  {
1100  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1101  data, // will update this pointer if using paged-kvcache
1102  make_tuple(kargs.hdim_v, length),
1103  make_tuple(kargs.stride_v, 1),
1105  number<1>{});
1106 
1107  return pad_tensor_view(
1108  v_dram_naive,
1111  }
1112  };
1113  const auto v_dram = [&]() {
1114  if constexpr(kIsPagedKV)
1115  {
1116  return make_v_dram(nullptr, kargs.page_block_size);
1117  }
1118  else
1119  {
1120  return make_v_dram(v_ptr, kargs.seqlen_k);
1121  }
1122  }();
1123 
1124  auto q_dram_window = make_tile_window(
1125  q_dram,
1126  [&]() {
1127  if constexpr(FmhaPipeline::kQLoadOnce)
1130  else
1132  }(),
1133  {i_m0, 0});
1134 
1135  auto k_page_block_navigator =
1136  [&, i_batch_ = i_batch, i_nhead_ = i_nhead / kargs.nhead_ratio_qk]() {
1137  if constexpr(kIsPagedKV)
1138  {
1139  const auto* block_indices =
1140  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
1141  i_batch_ * kargs.batch_stride_block_table;
1142  const index_t num_blocks =
1143  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
1144 
1145  const long_index_t fixed_offset =
1146  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_k;
1147 
1148  return make_page_block_navigator<const KDataType, 0>(
1149  kargs.k_ptr,
1150  kargs.batch_stride_k, // kcache page-block stride/size
1151  fixed_offset,
1152  block_indices,
1153  num_blocks,
1154  kargs.page_block_size,
1155  k_dram,
1156  make_k_dram(nullptr,
1157  (kv_l2p_offset + kargs.seqlen_k) -
1158  (num_blocks - 1) * kargs.page_block_size));
1159  }
1160  else
1161  {
1162  return make_page_block_navigator(k_dram);
1163  }
1164  }();
1165 
1166  auto v_page_block_navigator =
1167  [&, i_batch_ = i_batch, i_nhead_ = i_nhead / kargs.nhead_ratio_qk]() {
1168  if constexpr(kIsPagedKV)
1169  {
1170  const auto* block_indices =
1171  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
1172  i_batch_ * kargs.batch_stride_block_table;
1173  const index_t num_blocks =
1174  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
1175 
1176  const long_index_t fixed_offset =
1177  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_v;
1178 
1179  return make_page_block_navigator<const VDataType, 1>(
1180  kargs.v_ptr,
1181  kargs.batch_stride_v, // vcache page-block stride/size
1182  fixed_offset,
1183  block_indices,
1184  num_blocks,
1185  kargs.page_block_size,
1186  v_dram,
1187  make_v_dram(nullptr,
1188  (kv_l2p_offset + kargs.seqlen_k) -
1189  (num_blocks - 1) * kargs.page_block_size));
1190  }
1191  else
1192  {
1193  return make_page_block_navigator(v_dram);
1194  }
1195  }();
1196 
1197  auto k_dram_window_lengths =
1199  auto v_dram_window_lengths =
1201 
1204  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1205  constexpr auto bias_dram_window_lengths =
1208  {
1209  const BiasDataType* bias_ptr =
1210  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1211  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1212  batch_offset_bias;
1213 
1214  const auto bias_dram = [&]() {
1215  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1216  bias_ptr,
1217  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1218  make_tuple(kargs.stride_bias, 1),
1220  number<1>{});
1221 
1222  return pad_tensor_view(bias_dram_naive,
1223  bias_dram_window_lengths,
1225  }();
1226 
1227  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1228  }
1229  else
1230  {
1231  return make_null_tile_window(bias_dram_window_lengths);
1232  }
1233  }();
1234 
1235  // lse
1236  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1237  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1238  if constexpr(kStoreLSE)
1239  {
1240  LSEDataType* lse_ptr =
1241  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1242  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
1243 
1244  const auto lse_dram = [&]() {
1245  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1246  lse_ptr,
1247  make_tuple(kargs.seqlen_q),
1248  make_tuple(1),
1249  number<1>{},
1250  number<1>{});
1251 
1252  return pad_tensor_view(
1253  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1254  }();
1255 
1256  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1257  }
1258  else
1259  {
1260  return make_null_tile_window(lse_dram_window_lengths);
1261  }
1262  }();
1263 
1264  FmhaMask mask = [&]() {
1265  if constexpr(kHasMask)
1266  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1267  kargs.window_size_left,
1268  kargs.window_size_right,
1269  kargs.sink_size,
1270  kargs.seqlen_q,
1271  kargs.seqlen_k,
1273  else
1274  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1275  }();
1276 
1278  // Workaround: before C++20, lambdas cannot capture structured bindings, so capture i_batch/i_nhead explicitly via init-capture.
1278  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1279  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1280  {
1281  // data loading, shared by entire wg
1282  // TODO: how to use s_read?
1283  SaccDataType slope =
1284  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1285  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1286 #if CK_TILE_FMHA_FWD_FAST_EXP2
1287  slope *= ck_tile::log2e_v<>;
1288 #endif
1289  if constexpr(kHasMask)
1290  {
1291  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1292  kargs.window_size_left,
1293  kargs.window_size_right,
1294  kargs.seqlen_q,
1295  kargs.seqlen_k,
1296  kargs.mask_type);
1297  }
1298  else
1299  {
1301  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1302  }
1303  }
1304  else
1305  {
1307  }
1308  }();
1309 
1310  AttentionVariant variant;
1311  const auto variant_params = [&] {
1312  if constexpr(kHasLogitsSoftCap)
1313  {
1315  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1316  }
1317  else
1318  {
1319  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1320  }
1321  }();
1322 
1323  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1324 
1325  auto o_acc_tile = [&]() {
1326  if constexpr(kDoFp8StaticQuant)
1327  {
1328  return FmhaPipeline{}(
1329  q_dram_window,
1330  identity{}, // q_element_func
1331  k_dram_window_lengths,
1332  k_page_block_navigator,
1333  identity{}, // k_element_func
1334  v_dram_window_lengths,
1335  v_page_block_navigator,
1336  identity{}, // v_element_func
1337  bias_dram_window,
1338  identity{}, // bias_element_func
1339  lse_dram_window,
1340  identity{}, // lse_element_func
1341  identity{}, // s_acc_element_func
1342  scales{kargs.scale_p}, // p_compute_element_func
1343  composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
1344  mask,
1345  position_encoding,
1346  kargs.scale_s,
1347  variant,
1348  variant_params,
1349  block_indices,
1350  kv_l2p_offset,
1351  smem_ptr);
1352  }
1353  else
1354  {
1355  return FmhaPipeline{}(q_dram_window,
1356  k_dram_window_lengths,
1357  k_page_block_navigator,
1358  v_dram_window_lengths,
1359  v_page_block_navigator,
1360  bias_dram_window,
1361  lse_dram_window,
1362  mask,
1363  position_encoding,
1364  kargs.scale_s,
1365  variant,
1366  variant_params,
1367  block_indices,
1368  kv_l2p_offset,
1369  smem_ptr);
1370  }
1371  }();
1372 
1373  // O DRAM and O DRAM window
1374  auto o_dram = [&]() {
1375  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1376  o_ptr,
1377  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1378  make_tuple(kargs.stride_o, 1),
1380  number<1>{});
1381  return pad_tensor_view(
1382  o_dram_naive,
1385  }();
1386 
1387  auto o_dram_window =
1388  make_tile_window(o_dram,
1390  {i_m0, i_n1});
1391 
1392  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1393  }
1394 };
1395 
1396 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
_Float16 fp16_t
Definition: half.hpp:110
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition: page_block_navigator.hpp:333
CK_TILE_HOST_DEVICE_EXTERN composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:526
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1633
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_position_encoding.hpp:137
Definition: fmha_fwd_pagedkv_kernel.hpp:282
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_pagedkv_kernel.hpp:284
ck_tile::index_t batch_idx
Definition: fmha_fwd_pagedkv_kernel.hpp:283
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_pagedkv_kernel.hpp:285
Definition: fmha_fwd_pagedkv_kernel.hpp:228
const int32_t * cache_batch_idx
Definition: fmha_fwd_pagedkv_kernel.hpp:229
Definition: fmha_fwd_pagedkv_kernel.hpp:216
ck_tile::index_t batch_stride_block_table
Definition: fmha_fwd_pagedkv_kernel.hpp:218
const int32_t * block_table_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:217
ck_tile::index_t page_block_size
Definition: fmha_fwd_pagedkv_kernel.hpp:219
Definition: fmha_fwd_pagedkv_kernel.hpp:184
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_pagedkv_kernel.hpp:187
const void * alibi_slope_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:186
Definition: fmha_fwd_pagedkv_kernel.hpp:179
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_pagedkv_kernel.hpp:180
Definition: fmha_fwd_pagedkv_kernel.hpp:244
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_pagedkv_kernel.hpp:248
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_pagedkv_kernel.hpp:247
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:245
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_pagedkv_kernel.hpp:252
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_pagedkv_kernel.hpp:250
Definition: fmha_fwd_pagedkv_kernel.hpp:172
ck_tile::index_t stride_bias
Definition: fmha_fwd_pagedkv_kernel.hpp:174
const void * bias_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:173
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_pagedkv_kernel.hpp:175
Definition: fmha_fwd_pagedkv_kernel.hpp:121
ck_tile::index_t hdim_v
Definition: fmha_fwd_pagedkv_kernel.hpp:130
ck_tile::index_t seqlen_q
Definition: fmha_fwd_pagedkv_kernel.hpp:127
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_pagedkv_kernel.hpp:144
ck_tile::index_t stride_o
Definition: fmha_fwd_pagedkv_kernel.hpp:141
const void * k_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:123
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_pagedkv_kernel.hpp:143
ck_tile::index_t stride_v
Definition: fmha_fwd_pagedkv_kernel.hpp:140
float scale_s
Definition: fmha_fwd_pagedkv_kernel.hpp:136
ck_tile::index_t stride_k
Definition: fmha_fwd_pagedkv_kernel.hpp:139
const void * v_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:124
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_pagedkv_kernel.hpp:145
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_pagedkv_kernel.hpp:146
ck_tile::index_t hdim_q
Definition: fmha_fwd_pagedkv_kernel.hpp:129
ck_tile::index_t seqlen_k
Definition: fmha_fwd_pagedkv_kernel.hpp:128
const void * q_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:122
ck_tile::index_t num_head_q
Definition: fmha_fwd_pagedkv_kernel.hpp:132
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_pagedkv_kernel.hpp:135
void * o_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:125
ck_tile::index_t stride_q
Definition: fmha_fwd_pagedkv_kernel.hpp:138
Definition: fmha_fwd_pagedkv_kernel.hpp:204
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_pagedkv_kernel.hpp:207
void * lse_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:205
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_pagedkv_kernel.hpp:206
Definition: fmha_fwd_pagedkv_kernel.hpp:114
Definition: fmha_fwd_pagedkv_kernel.hpp:198
float scale_p
Definition: fmha_fwd_pagedkv_kernel.hpp:199
float scale_o
Definition: fmha_fwd_pagedkv_kernel.hpp:200
Definition: fmha_fwd_pagedkv_kernel.hpp:268
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:271
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:270
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_pagedkv_kernel.hpp:273
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_pagedkv_kernel.hpp:275
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:269
Definition: fmha_fwd_pagedkv_kernel.hpp:150
float logits_soft_cap
Definition: fmha_fwd_pagedkv_kernel.hpp:167
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_pagedkv_kernel.hpp:153
float logits_soft_cap_rcp
Definition: fmha_fwd_pagedkv_kernel.hpp:168
Definition: fmha_fwd_pagedkv_kernel.hpp:191
ck_tile::index_t window_size_left
Definition: fmha_fwd_pagedkv_kernel.hpp:193
ck_tile::index_t sink_size
Definition: fmha_fwd_pagedkv_kernel.hpp:193
ck_tile::index_t window_size_right
Definition: fmha_fwd_pagedkv_kernel.hpp:193
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_pagedkv_kernel.hpp:194
Definition: fmha_fwd_pagedkv_kernel.hpp:211
ck_tile::index_t min_seqlen_q
Definition: fmha_fwd_pagedkv_kernel.hpp:212
Definition: fmha_fwd_pagedkv_kernel.hpp:223
bool is_gappy
Definition: fmha_fwd_pagedkv_kernel.hpp:224
Definition: fmha_fwd_pagedkv_kernel.hpp:67
Definition: fmha_fwd_pagedkv_kernel.hpp:28
static constexpr bool kHasSink
Definition: fmha_fwd_pagedkv_kernel.hpp:58
static constexpr bool kIsGroupMode
Definition: fmha_fwd_pagedkv_kernel.hpp:47
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q)
Definition: fmha_fwd_pagedkv_kernel.hpp:508
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_pagedkv_kernel.hpp:52
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_pagedkv_kernel.hpp:30
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_pagedkv_kernel.hpp:32
static constexpr bool kStoreLSE
Definition: fmha_fwd_pagedkv_kernel.hpp:54
static CK_TILE_HOST std::string GetName()
Definition: fmha_fwd_pagedkv_kernel.hpp:75
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_pagedkv_kernel.hpp:49
static constexpr bool kIsPagedKV
Definition: fmha_fwd_pagedkv_kernel.hpp:57
static constexpr bool kSkipMinSeqlenQ
Definition: fmha_fwd_pagedkv_kernel.hpp:56
static CK_TILE_HOST void PrintParameters(const Kargs &kargs, int num_batches)
Definition: fmha_fwd_pagedkv_kernel.hpp:715
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_pagedkv_kernel.hpp:31
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_pagedkv_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:41
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_fwd_pagedkv_kernel.hpp:877
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:39
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_pagedkv_kernel.hpp:812
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_pagedkv_kernel.hpp:35
static constexpr bool kUseAsyncCopy
Definition: fmha_fwd_pagedkv_kernel.hpp:64
static constexpr bool kHasMask
Definition: fmha_fwd_pagedkv_kernel.hpp:62
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k)
Definition: fmha_fwd_pagedkv_kernel.hpp:787
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_pagedkv_kernel.hpp:29
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_pagedkv_kernel.hpp:42
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type)
Definition: fmha_fwd_pagedkv_kernel.hpp:419
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:40
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q)
Definition: fmha_fwd_pagedkv_kernel.hpp:634
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_pagedkv_kernel.hpp:50
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_pagedkv_kernel.hpp:45
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_pagedkv_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:43
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_pagedkv_kernel.hpp:61
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_pagedkv_kernel.hpp:889
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_pagedkv_kernel.hpp:894
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_pagedkv_kernel.hpp:60
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:37
static constexpr bool kDoFp8StaticQuant
Definition: fmha_fwd_pagedkv_kernel.hpp:55
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_pagedkv_kernel.hpp:279
static constexpr auto BiasEnum
Definition: fmha_fwd_pagedkv_kernel.hpp:53
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type)
Definition: fmha_fwd_pagedkv_kernel.hpp:290
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: functional.hpp:114
Definition: unary_element_function.hpp:55
Definition: math.hpp:28
Definition: sequence.hpp:49