/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File
fmha_fwd_splitkv_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 
14 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
15 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
16 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
17 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
18 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
19 
20 namespace ck_tile {
21 
22 template <typename FmhaPipeline_, typename EpiloguePipeline_>
24 {
27  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
28  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
29 
30  static_assert(kBlockPerCu > 0);
31  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
32 
41 
43 
44  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
45  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
46  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
47  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
48  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
49  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
50  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
51  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
52  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
53  static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
54  static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink;
55  static constexpr bool kMergeNumHeadGroupsSeqLenQ =
56  FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
59  static constexpr bool kHasMask = FmhaMask::IsMasking;
60 
61  static_assert(!kMergeNumHeadGroupsSeqLenQ ||
63  !kHasMask));
64 
65  // clang-format off
66  template <typename T> struct t2s;
67  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
68  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
69  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
70  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
71  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
72  // clang-format on
73 
74  CK_TILE_HOST static std::string GetName()
75  {
76  // sync with generate.py
77  // clang-format off
78  using bfs = typename FmhaPipeline::BlockFmhaShape;
79  using g0br = typename bfs::Gemm0BlockWarps;
80  using g1br = typename bfs::Gemm1BlockWarps;
81  using g0wt = typename bfs::Gemm0WarpTile;
82  using g1wt = typename bfs::Gemm1WarpTile;
83  #define _SS_ std::string
84  #define _TS_ std::to_string
85  auto pn = [&] () {
86  std::string n;
87  if (kPadSeqLenQ) n += "s";
88  if (kPadSeqLenK) n += "sk";
89  if (kPadHeadDimQ) n += "d";
90  if (kPadHeadDimV) n += "dv";
91  return n.empty() ? n : std::string("p") + n; }();
92  return
93  _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
94  "_" + (kIsGroupMode ? "group" : "batch") + "_"
95  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
96  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
97  "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
98  "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
99  "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
100  "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
101  (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
102  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
103  (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
104  (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
105  (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
106  #undef _SS_
107  #undef _TS_
108  // clang-format on
109  }
110 
111  template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
112  // arg
113  struct EmptyKargs
114  {
115  };
116 
117  // kargs use aggregate initializer, so no constructor will provided
118  // use inheritance to minimize karg size
119  // user need to use MakeKargs() function to create kargs.
120  struct CommonKargs
121  {
122  const void* q_ptr;
123  const void* k_ptr;
124  const void* v_ptr;
125  void* lse_acc_ptr;
126  void* o_acc_ptr;
127 
129 
134 
136  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
137  // if this param is larger than 1, indicate MQA/GQA case
140 
141  float scale_s;
142 
147 
153 
156  };
157 
159  {
160  LogitsSoftCapKargs() = default;
161 
162  void init_logits_soft_cap(float logits_soft_cap_)
163  {
164  if(0 < logits_soft_cap_)
165  {
166  logits_soft_cap = logits_soft_cap_;
168  }
169  else
170  {
171  logits_soft_cap = 0.f;
172  logits_soft_cap_rcp = 0.f;
173  }
174  }
175 
178  };
179 
181  {
182  const void* bias_ptr = nullptr;
185  };
186 
188  {
190  };
191 
192  struct AlibiKargs
193  {
194  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
195  const void* alibi_slope_ptr;
196  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
197  };
198 
199  struct MaskKargs
200  {
201  // ck_tile::index_t window_size_left, window_size_right;
204  };
205 
207  {
208  float scale_p;
209  };
210 
212  {
216  };
217 
219  {
220  bool is_gappy = false;
221  };
222 
224  {
226  };
227 
229  : CommonKargs,
230  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
231  BatchModeBiasKargs,
232  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
233  AlibiKargs,
234  EmptyKargs<0>>>,
235  std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
236  std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
237  std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>,
238  std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<3>>
239  {
241 
243  ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
244  // single kcache page-block
245  ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
246  // single vcache page-block
249  };
250 
252  : CommonKargs,
253  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
254  CommonBiasKargs,
255  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
256  AlibiKargs,
257  EmptyKargs<0>>>,
258  std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
259  std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
260  std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>,
261  std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<4>>
262  {
266 
267  ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
268  // for single kcache page-block
269  ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
270  // for single vcache page-block
271  };
272 
273  using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
274 
276  {
280  };
281 
282  template <bool Cond = !kIsGroupMode>
283  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
284  MakeKargs(const void* q_ptr,
285  const void* k_ptr,
286  const void* v_ptr,
287  const void* bias_ptr,
288  void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
289  final lse */
290  void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
291  o */
292  ck_tile::index_t batch,
293  ck_tile::index_t seqlen_q,
294  ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
295  const void* seqlen_k_ptr, // only used for (paged-) kvcache
296  ck_tile::index_t hdim_q,
297  ck_tile::index_t hdim_v,
298  ck_tile::index_t num_head_q,
299  ck_tile::index_t nhead_ratio_qk,
300  ck_tile::index_t num_splits,
301  const void* block_table_ptr,
302  ck_tile::index_t batch_stride_block_table,
303  ck_tile::index_t page_block_size,
304  const void* cache_batch_idx,
305  float scale_s,
306  float scale_p,
307  float logits_soft_cap,
308  ck_tile::index_t stride_q,
309  ck_tile::index_t stride_k,
310  ck_tile::index_t stride_v,
311  ck_tile::index_t stride_bias,
312  ck_tile::index_t stride_o_acc,
313  ck_tile::index_t nhead_stride_q,
314  ck_tile::index_t nhead_stride_k,
315  ck_tile::index_t nhead_stride_v,
316  ck_tile::index_t nhead_stride_bias,
317  ck_tile::index_t nhead_stride_lse_acc,
318  ck_tile::index_t nhead_stride_o_acc,
319  ck_tile::index_t batch_stride_q,
320  ck_tile::index_t batch_stride_k,
321  ck_tile::index_t batch_stride_v,
322  ck_tile::index_t batch_stride_bias,
323  ck_tile::index_t batch_stride_lse_acc,
324  ck_tile::index_t batch_stride_o_acc,
325  ck_tile::index_t split_stride_lse_acc,
326  ck_tile::index_t split_stride_o_acc,
327  ck_tile::index_t window_size_left,
328  ck_tile::index_t window_size_right,
329  ck_tile::index_t sink_size,
330  ck_tile::index_t mask_type)
331  {
332  Kargs kargs{{q_ptr,
333  k_ptr,
334  v_ptr,
335  lse_acc_ptr,
336  o_acc_ptr,
337  batch,
338  seqlen_q,
339  seqlen_k,
340  hdim_q,
341  hdim_v,
342  num_head_q,
343  nhead_ratio_qk,
344  num_splits,
345 #if CK_TILE_FMHA_FWD_FAST_EXP2
346  static_cast<float>(scale_s * ck_tile::log2e_v<>),
347 #else
348  scale_s,
349 #endif
350  stride_q,
351  stride_k,
352  stride_v,
353  stride_o_acc,
354  nhead_stride_q,
355  nhead_stride_k,
356  nhead_stride_v,
357  nhead_stride_lse_acc,
358  nhead_stride_o_acc,
359  split_stride_lse_acc,
360  split_stride_o_acc}, // args for common karg
361  {}, // placeholder for bias
362  {}, // placeholder for mask
363  {}, // placeholder for fp8_static_quant args
364  {}, // placeholder for paged-block table or cache_batch_idx
365  {}, // placeholder for logits_soft_cap
366  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
367  batch_stride_q,
368  batch_stride_k,
369  batch_stride_v,
370  batch_stride_lse_acc,
371  batch_stride_o_acc};
372 
374  {
375  kargs.bias_ptr = bias_ptr;
376  kargs.stride_bias = stride_bias;
377  kargs.nhead_stride_bias = nhead_stride_bias;
378  kargs.batch_stride_bias = batch_stride_bias;
379  }
380  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
381  {
382  kargs.alibi_slope_ptr = bias_ptr;
383  kargs.alibi_slope_stride = stride_bias;
384  }
385  if constexpr(kHasMask)
386  {
387  kargs.window_size_left = window_size_left;
388  kargs.window_size_right = window_size_right;
389  kargs.sink_size = sink_size;
390  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
391  }
392  if constexpr(kDoFp8StaticQuant)
393  {
394  kargs.scale_p = scale_p;
395  }
396  if constexpr(kIsPagedKV)
397  {
398  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
399  kargs.batch_stride_block_table = batch_stride_block_table;
400  kargs.page_block_size = page_block_size;
401  }
402  else
403  {
404  kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
405  }
406  if constexpr(kHasLogitsSoftCap)
407  {
408  kargs.init_logits_soft_cap(logits_soft_cap);
409  }
410 
411  return kargs;
412  }
413 
414  template <bool Cond = kIsGroupMode>
415  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
416  MakeKargs(const void* q_ptr,
417  const void* k_ptr,
418  const void* v_ptr,
419  const void* bias_ptr,
420  void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
421  final lse */
422  void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
423  o */
424  ck_tile::index_t batch,
425  const void* seqstart_q_ptr,
426  const void* seqstart_k_ptr,
427  const void* seqlen_k_ptr,
428  ck_tile::index_t hdim_q,
429  ck_tile::index_t hdim_v,
430  ck_tile::index_t num_head_q,
431  ck_tile::index_t nhead_ratio_qk,
432  ck_tile::index_t num_splits,
433  const void* block_table_ptr,
434  ck_tile::index_t batch_stride_block_table,
435  ck_tile::index_t page_block_size,
436  bool is_gappy,
437  float scale_s,
438  float scale_p,
439  float logits_soft_cap,
440  ck_tile::index_t stride_q,
441  ck_tile::index_t stride_k,
442  ck_tile::index_t stride_v,
443  ck_tile::index_t stride_bias,
444  ck_tile::index_t stride_o_acc,
445  ck_tile::index_t nhead_stride_q,
446  ck_tile::index_t nhead_stride_k,
447  ck_tile::index_t nhead_stride_v,
448  ck_tile::index_t nhead_stride_bias,
449  ck_tile::index_t nhead_stride_lse_acc,
450  ck_tile::index_t nhead_stride_o_acc,
451  ck_tile::index_t batch_stride_k, // only used for paged-kvcache
452  ck_tile::index_t batch_stride_v, // only used for paged-kvcache
453  ck_tile::index_t split_stride_lse_acc,
454  ck_tile::index_t split_stride_o_acc,
455  ck_tile::index_t window_size_left,
456  ck_tile::index_t window_size_right,
457  ck_tile::index_t sink_size,
458  ck_tile::index_t mask_type)
459  {
460  Kargs kargs{{q_ptr,
461  k_ptr,
462  v_ptr,
463  lse_acc_ptr,
464  o_acc_ptr,
465  batch,
466  -1, // seqlen_q will be updated by another pointer
467  -1, // seqlen_k will be updated by another pointer
468  hdim_q,
469  hdim_v,
470  num_head_q,
471  nhead_ratio_qk,
472  num_splits,
473 #if CK_TILE_FMHA_FWD_FAST_EXP2
474  static_cast<float>(scale_s * ck_tile::log2e_v<>),
475 #else
476  scale_s,
477 #endif
478  stride_q,
479  stride_k,
480  stride_v,
481  stride_o_acc,
482  nhead_stride_q,
483  nhead_stride_k,
484  nhead_stride_v,
485  nhead_stride_lse_acc,
486  nhead_stride_o_acc,
487  split_stride_lse_acc,
488  split_stride_o_acc}, // args for common karg
489  {}, // placeholder for bias
490  {}, // placeholder for mask
491  {}, // placeholder for fp8_static_quant args
492  {}, // placeholder for paged-block table
493  {}, // placeholder for logits_soft_cap
494  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
495  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
496  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
497  batch_stride_k,
498  batch_stride_v};
499 
501  {
502  kargs.bias_ptr = bias_ptr;
503  kargs.stride_bias = stride_bias;
504  kargs.nhead_stride_bias = nhead_stride_bias;
505  }
506  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
507  {
508  kargs.alibi_slope_ptr = bias_ptr;
509  kargs.alibi_slope_stride = stride_bias;
510  }
511  if constexpr(kHasMask)
512  {
513  kargs.window_size_left = window_size_left;
514  kargs.window_size_right = window_size_right;
515  kargs.sink_size = sink_size;
516  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
517  }
518  if constexpr(kDoFp8StaticQuant)
519  {
520  kargs.scale_p = scale_p;
521  }
522  if constexpr(kIsPagedKV)
523  {
524  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
525  kargs.batch_stride_block_table = batch_stride_block_table;
526  kargs.page_block_size = page_block_size;
527  kargs.is_gappy = is_gappy;
528  }
529  if constexpr(kHasLogitsSoftCap)
530  {
531  kargs.init_logits_soft_cap(logits_soft_cap);
532  }
533 
534  return kargs;
535  }
536 
537  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
538  ck_tile::index_t nhead_q,
539  ck_tile::index_t nhead_kv,
540  ck_tile::index_t max_seqlen_q,
541  ck_tile::index_t hdim_v,
542  ck_tile::index_t num_splits)
543  {
544  ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q;
545  ck_tile::index_t max_seqlen_q_ =
546  max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1);
547 
548  // TODO: this may need tuning
549  return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) *
550  ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
551  nhead_,
552  batch_size);
553  }
554 
555  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
556  {
557  const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
558 
559  const auto f = [](index_t dividend, index_t divisor) {
560  index_t quotient = dividend / divisor;
561  index_t modulus = dividend - quotient * divisor;
562  return ck_tile::make_tuple(quotient, modulus);
563  };
564 
565  const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits);
566  const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
567  const index_t i_nhead = blockIdx.y;
568  const index_t i_batch = blockIdx.z;
569 
570  if constexpr(kHasMask)
571  {
572  // assume that num_tile_n1 is always 1
573  return ck_tile::make_tuple(
574  (gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
575  }
576  else
577  {
578  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
579  }
580  }
581 
582  CK_TILE_HOST static dim3 BlockSize()
583  {
584  if(is_wave32())
585  {
586  return dim3(kBlockSize / 2);
587  }
588  else
589  {
590  return dim3(kBlockSize);
591  }
592  }
593 
595  {
596  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
597  }
598 
599  CK_TILE_DEVICE void operator()(Kargs kargs) const
600  {
601  // allocate LDS
602  __shared__ char smem_ptr[GetSmemSize()];
603 
604  // divide problem
605  const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
606 
607  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
608  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
609 
610  long_index_t batch_offset_q = 0;
611  long_index_t batch_offset_k = 0; // unused for paged-kvcache
612  long_index_t batch_offset_v = 0; // unused for paged-kvcache
613  long_index_t batch_offset_bias = 0;
614  long_index_t batch_offset_lse_acc = 0;
615  long_index_t batch_offset_o_acc = 0;
616  index_t kv_l2p_offset =
617  0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
618 
619  if constexpr(kIsGroupMode)
620  {
621  // get starting offset for each batch
622  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
623  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
624 
625  batch_offset_q = query_start * kargs.stride_q;
626  batch_offset_k = key_start * kargs.stride_k;
627  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
628  {
629  batch_offset_v = key_start * kargs.stride_v;
630  }
631  else
632  {
633  batch_offset_v = key_start;
634  }
636  {
637  batch_offset_bias = query_start * kargs.stride_bias;
638  }
639 
640  batch_offset_lse_acc = query_start;
641  batch_offset_o_acc = query_start * kargs.stride_o_acc;
642 
643  // get real # queries & # keys under group mode
644  kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
645 
646  // # of required blocks is different in each groups, terminate unnecessary blocks
647  // earlier
648  if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0)
649  {
650  return;
651  }
652 
653  if(kargs.seqlen_k_ptr != nullptr)
654  {
655  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
656  }
657  else
658  {
659  kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
660  }
661 
662  if constexpr(kIsPagedKV)
663  {
664  if(kargs.is_gappy)
665  {
666  // seqstart_k_ptr has different meaning in this case
667  kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
668  }
669  }
670  }
671  else
672  {
673  const index_t i_cache_batch = [&, i_batch_ = i_batch] {
674  if constexpr(kIsPagedKV)
675  {
676  return i_batch_;
677  }
678  else
679  {
680  return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
681  : i_batch_);
682  }
683  }();
684 
685  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
686  batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
687  batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
688  batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
689  batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
690 
692  {
693  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
694  }
695 
696  if(kargs.seqlen_k_ptr != nullptr)
697  {
698  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
699  }
700  }
701 
702  // for simplicity, batch stride we just modify the pointer
703  const index_t i_nhead_k =
704  (kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
705 
706  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
707  static_cast<long_index_t>(i_nhead) *
708  (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
709  kargs.nhead_stride_q +
710  batch_offset_q;
711  const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
712  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
713  batch_offset_k;
714  const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
715  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
716  batch_offset_v;
717 
718  ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
719  static_cast<long_index_t>(i_nhead) *
720  (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
721  kargs.nhead_stride_o_acc +
722  batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
723 
724  // Q/K/V DRAM and DRAM window
725  const auto q_dram = [&] {
726  const auto q_dram_naive = [&] {
727  if constexpr(kMergeNumHeadGroupsSeqLenQ)
728  {
729  // reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
730  // hdim_q)
731  const auto view = make_naive_tensor_view<address_space_enum::global>(
732  q_ptr,
733  make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
734  make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1),
736  number<1>{});
737 
738  return transform_tensor_view(
739  view,
740  make_tuple(
741  make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
742  make_pass_through_transform(kargs.hdim_q)),
745  }
746  else
747  {
748  return make_naive_tensor_view<address_space_enum::global>(
749  q_ptr,
750  make_tuple(kargs.seqlen_q, kargs.hdim_q),
751  make_tuple(kargs.stride_q, 1),
753  number<1>{});
754  }
755  }();
756 
757  if constexpr(FmhaPipeline::kQLoadOnce)
758  {
759  return pad_tensor_view(
760  q_dram_naive,
763  }
764  else
765  {
766  return pad_tensor_view(
767  q_dram_naive,
770  }
771  }();
772 
773  const auto make_k_dram = [&](const KDataType* data, index_t height) {
774  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
775  data, // will update this pointer if using paged-kvcache
776  make_tuple(height, kargs.hdim_q),
777  make_tuple(kargs.stride_k, 1),
779  number<1>{});
780 
781  return pad_tensor_view(
782  k_dram_naive,
785  };
786  const auto k_dram = [&]() {
787  if constexpr(kIsPagedKV)
788  {
789  return make_k_dram(nullptr, kargs.page_block_size);
790  }
791  else
792  {
793  return make_k_dram(k_ptr, kargs.seqlen_k);
794  }
795  }();
796 
797  const auto make_v_dram = [&](const VDataType* data, index_t length) {
798  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
799  {
800  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
801  data, // will update this pointer if using paged-kvcache
802  make_tuple(length, kargs.hdim_v),
803  make_tuple(kargs.stride_v, 1),
805  number<1>{});
806 
807  const auto v_dram_transposed =
808  transform_tensor_view(v_dram_naive,
813 
814  return pad_tensor_view(
815  v_dram_transposed,
818  }
819  else
820  {
821  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
822  data, // will update this pointer if using paged-kvcache
823  make_tuple(kargs.hdim_v, length),
824  make_tuple(kargs.stride_v, 1),
826  number<1>{});
827 
828  return pad_tensor_view(
829  v_dram_naive,
832  }
833  };
834  const auto v_dram = [&]() {
835  if constexpr(kIsPagedKV)
836  {
837  return make_v_dram(nullptr, kargs.page_block_size);
838  }
839  else
840  {
841  return make_v_dram(v_ptr, kargs.seqlen_k);
842  }
843  }();
844 
845  auto k_page_block_navigator = [&, i_batch_ = i_batch]() {
846  if constexpr(kIsPagedKV)
847  {
848  const auto* block_indices =
849  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
850  i_batch_ * kargs.batch_stride_block_table;
851  const index_t num_blocks =
852  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
853 
854  const long_index_t fixed_offset =
855  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k;
856 
857  return make_page_block_navigator<const KDataType, 0>(
858  kargs.k_ptr,
859  kargs.batch_stride_k, // kcache page-block stride/size
860  fixed_offset,
861  block_indices,
862  num_blocks,
863  kargs.page_block_size,
864  k_dram,
865  make_k_dram(nullptr,
866  (kv_l2p_offset + kargs.seqlen_k) -
867  (num_blocks - 1) * kargs.page_block_size));
868  }
869  else
870  {
871  return make_page_block_navigator(k_dram);
872  }
873  }();
874 
875  auto v_page_block_navigator = [&, i_batch_ = i_batch]() {
876  if constexpr(kIsPagedKV)
877  {
878  const auto* block_indices =
879  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
880  i_batch_ * kargs.batch_stride_block_table;
881  const index_t num_blocks =
882  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
883 
884  const long_index_t fixed_offset =
885  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v;
886 
887  return make_page_block_navigator<const VDataType, 1>(
888  kargs.v_ptr,
889  kargs.batch_stride_v, // vcache page-block stride/size
890  fixed_offset,
891  block_indices,
892  num_blocks,
893  kargs.page_block_size,
894  v_dram,
895  make_v_dram(nullptr,
896  (kv_l2p_offset + kargs.seqlen_k) -
897  (num_blocks - 1) * kargs.page_block_size));
898  }
899  else
900  {
901  return make_page_block_navigator(v_dram);
902  }
903  }();
904 
905  auto q_dram_window = make_tile_window(
906  q_dram,
907  [&]() {
908  if constexpr(FmhaPipeline::kQLoadOnce)
911  else
913  }(),
914  {i_m0, 0});
915 
916  auto k_dram_window_lengths =
918  auto v_dram_window_lengths =
920 
923  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
924  constexpr auto bias_dram_window_lengths =
927  {
928  const BiasDataType* bias_ptr =
929  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
930  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
931  batch_offset_bias;
932 
933  const auto bias_dram = [&]() {
934  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
935  bias_ptr,
936  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
937  make_tuple(kargs.stride_bias, 1),
939  number<1>{});
940 
941  return pad_tensor_view(
942  bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
943  }();
944 
945  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
946  }
947  else
948  {
949  return make_null_tile_window(bias_dram_window_lengths);
950  }
951  }();
952 
953  // lse acc
954  auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() {
955  constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
956  LSEDataType* lse_acc_ptr = reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
957  static_cast<long_index_t>(i_nhead_) *
958  (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
959  kargs.nhead_stride_lse_acc +
960  batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;
961 
962  const auto lse_acc_dram = [&] {
963  const auto lse_acc_dram_naive = [&] {
964  if constexpr(kMergeNumHeadGroupsSeqLenQ)
965  {
966  // reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q)
967  const auto view = make_naive_tensor_view<address_space_enum::global>(
968  lse_acc_ptr,
969  make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q),
970  make_tuple(kargs.nhead_stride_lse_acc, 1),
971  number<1>{},
972  number<1>{});
973 
974  return transform_tensor_view(view,
976  kargs.nhead_ratio_qk, kargs.seqlen_q))),
979  }
980  else
981  {
982  return make_naive_tensor_view<address_space_enum::global>(
983  lse_acc_ptr,
984  make_tuple(kargs.seqlen_q),
985  make_tuple(1),
986  number<1>{},
987  number<1>{});
988  }
989  }();
990  return pad_tensor_view(
991  lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{});
992  }();
993 
994  return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0});
995  }();
996 
997  FmhaMask mask = [&]() {
998  if constexpr(kHasMask)
999  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1000  kargs.window_size_left,
1001  kargs.window_size_right,
1002  kargs.sink_size,
1003  kargs.seqlen_q,
1004  kargs.seqlen_k,
1006  else
1007  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1008  }();
1009 
1010  // WA i_batch capture structure binding before c++20
1011  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1012  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1013  {
1014  // data loading, shared by entire wg
1015  // TODO: how to use s_read?
1016  SaccDataType slope =
1017  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1018  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1019 #if CK_TILE_FMHA_FWD_FAST_EXP2
1020  slope *= ck_tile::log2e_v<>;
1021 #endif
1022  if constexpr(kHasMask)
1023  {
1024  return make_alibi_from_lr_mask<SaccDataType, true, 32>(slope,
1025  kargs.window_size_left,
1026  kargs.window_size_right,
1027  kargs.seqlen_q,
1028  kargs.seqlen_k,
1029  kargs.mask_type);
1030  }
1031  else
1032  {
1033  return Alibi<SaccDataType, true, 32>{
1034  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1035  }
1036  }
1037  else
1038  {
1039  return EmptyPositionEncoding<SaccDataType>{};
1040  }
1041  }();
1042 
1043  AttentionVariant variant;
1044  const auto variant_params = [&] {
1045  if constexpr(kHasLogitsSoftCap)
1046  {
1048  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1049  }
1050  else
1051  {
1052  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1053  }
1054  }();
1055 
1056  BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};
1057 
1058  auto o_acc_tile = [&, i_split_ = i_split]() {
1059  if constexpr(kDoFp8StaticQuant)
1060  {
1061  return FmhaPipeline{}(q_dram_window,
1062  identity{}, // q_element_func
1063  k_dram_window_lengths,
1064  k_page_block_navigator,
1065  identity{}, // k_element_func
1066  v_dram_window_lengths,
1067  v_page_block_navigator,
1068  identity{}, // v_element_func
1069  bias_dram_window,
1070  identity{}, // bias_element_func
1071  lse_acc_dram_window,
1072  identity{}, // lse_element_func
1073  identity{}, // s_acc_element_func
1074  scales{kargs.scale_p}, // p_compute_element_func
1075  identity{}, // o_acc_element_func
1076  kargs.num_splits,
1077  i_split_,
1078  mask,
1079  position_encoding,
1080  kargs.scale_s,
1081  variant,
1082  variant_params,
1083  block_indices,
1084  kv_l2p_offset,
1085  smem_ptr);
1086  }
1087  else
1088  {
1089  return FmhaPipeline{}(q_dram_window,
1090  k_dram_window_lengths,
1091  k_page_block_navigator,
1092  v_dram_window_lengths,
1093  v_page_block_navigator,
1094  bias_dram_window,
1095  lse_acc_dram_window,
1096  kargs.num_splits,
1097  i_split_,
1098  mask,
1099  position_encoding,
1100  kargs.scale_s,
1101  variant,
1102  variant_params,
1103  block_indices,
1104  kv_l2p_offset,
1105  smem_ptr);
1106  }
1107  }();
1108 
1109  // Oacc DRAM and Oacc DRAM window
1110  auto o_acc_dram = [&] {
1111  const auto o_acc_dram_naive = [&] {
1112  if constexpr(kMergeNumHeadGroupsSeqLenQ)
1113  {
1114  // reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
1115  // hdim_v)
1116  const auto view = make_naive_tensor_view<address_space_enum::global>(
1117  o_acc_ptr,
1118  make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),
1119  make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1),
1120  number<FmhaPipeline::kAlignmentOacc>{},
1121  number<1>{});
1122 
1123  return transform_tensor_view(
1124  view,
1125  make_tuple(
1126  make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
1127  make_pass_through_transform(kargs.hdim_v)),
1128  make_tuple(sequence<0, 1>{}, sequence<2>{}),
1129  make_tuple(sequence<0>{}, sequence<1>{}));
1130  }
1131  else
1132  {
1133  return make_naive_tensor_view<address_space_enum::global>(
1134  o_acc_ptr,
1135  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1136  make_tuple(kargs.stride_o_acc, 1),
1137  number<FmhaPipeline::kAlignmentOacc>{},
1138  number<1>{});
1139  }
1140  }();
1141 
1142  return pad_tensor_view(
1143  o_acc_dram_naive,
1144  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
1145  sequence<kPadSeqLenQ, kPadHeadDimV>{});
1146  }();
1147 
1148  auto o_acc_dram_window =
1149  make_tile_window(o_acc_dram,
1150  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
1151  {i_m0, i_n1});
1152 
1153  EpiloguePipeline{}(o_acc_dram_window, o_acc_tile, nullptr);
1154  }
1155 };
1156 
1157 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
_Float16 fp16_t
Definition: half.hpp:110
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition: page_block_navigator.hpp:333
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:526
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1690
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1633
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
CK_TILE_HOST_DEVICE_EXTERN scales(Scale) -> scales< Scale >
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_attention_bias_enum.hpp:19
Definition: fmha_fwd_splitkv_kernel.hpp:193
const void * alibi_slope_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:195
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_splitkv_kernel.hpp:196
Definition: fmha_fwd_splitkv_kernel.hpp:188
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:189
Definition: fmha_fwd_splitkv_kernel.hpp:239
ck_tile::index_t batch_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:247
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:245
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:242
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:243
ck_tile::index_t batch_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:248
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:240
Definition: fmha_fwd_splitkv_kernel.hpp:276
ck_tile::index_t batch_idx
Definition: fmha_fwd_splitkv_kernel.hpp:277
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_splitkv_kernel.hpp:279
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_splitkv_kernel.hpp:278
Definition: fmha_fwd_splitkv_kernel.hpp:224
const int32_t * cache_batch_idx
Definition: fmha_fwd_splitkv_kernel.hpp:225
Definition: fmha_fwd_splitkv_kernel.hpp:181
const void * bias_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:182
ck_tile::index_t stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:183
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:184
Definition: fmha_fwd_splitkv_kernel.hpp:121
ck_tile::index_t split_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:155
ck_tile::index_t nhead_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:152
const void * k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:123
ck_tile::index_t num_splits
Definition: fmha_fwd_splitkv_kernel.hpp:139
void * lse_acc_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:125
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:148
void * o_acc_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:126
ck_tile::index_t nhead_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:151
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:150
ck_tile::index_t hdim_q
Definition: fmha_fwd_splitkv_kernel.hpp:132
const void * v_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:124
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:149
ck_tile::index_t split_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:154
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_splitkv_kernel.hpp:138
ck_tile::index_t stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:143
ck_tile::index_t seqlen_k
Definition: fmha_fwd_splitkv_kernel.hpp:131
ck_tile::index_t batch
Definition: fmha_fwd_splitkv_kernel.hpp:128
const void * q_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:122
ck_tile::index_t stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:145
ck_tile::index_t stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:144
ck_tile::index_t num_head_q
Definition: fmha_fwd_splitkv_kernel.hpp:135
ck_tile::index_t seqlen_q
Definition: fmha_fwd_splitkv_kernel.hpp:130
ck_tile::index_t stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:146
float scale_s
Definition: fmha_fwd_splitkv_kernel.hpp:141
ck_tile::index_t hdim_v
Definition: fmha_fwd_splitkv_kernel.hpp:133
Definition: fmha_fwd_splitkv_kernel.hpp:212
ck_tile::index_t page_block_size
Definition: fmha_fwd_splitkv_kernel.hpp:215
ck_tile::index_t batch_stride_block_table
Definition: fmha_fwd_splitkv_kernel.hpp:214
const int32_t * block_table_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:213
Definition: fmha_fwd_splitkv_kernel.hpp:114
Definition: fmha_fwd_splitkv_kernel.hpp:207
float scale_p
Definition: fmha_fwd_splitkv_kernel.hpp:208
Definition: fmha_fwd_splitkv_kernel.hpp:262
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:265
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:267
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:263
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:264
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:269
Definition: fmha_fwd_splitkv_kernel.hpp:219
bool is_gappy
Definition: fmha_fwd_splitkv_kernel.hpp:220
Definition: fmha_fwd_splitkv_kernel.hpp:159
float logits_soft_cap_rcp
Definition: fmha_fwd_splitkv_kernel.hpp:177
float logits_soft_cap
Definition: fmha_fwd_splitkv_kernel.hpp:176
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_splitkv_kernel.hpp:162
Definition: fmha_fwd_splitkv_kernel.hpp:200
ck_tile::index_t sink_size
Definition: fmha_fwd_splitkv_kernel.hpp:202
ck_tile::index_t window_size_right
Definition: fmha_fwd_splitkv_kernel.hpp:202
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_splitkv_kernel.hpp:203
ck_tile::index_t window_size_left
Definition: fmha_fwd_splitkv_kernel.hpp:202
Definition: fmha_fwd_splitkv_kernel.hpp:66
Definition: fmha_fwd_splitkv_kernel.hpp:24
static constexpr auto BiasEnum
Definition: fmha_fwd_splitkv_kernel.hpp:50
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_splitkv_kernel.hpp:27
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_splitkv_kernel.hpp:36
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_splitkv_kernel.hpp:594
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_splitkv_kernel.hpp:26
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type)
Definition: fmha_fwd_splitkv_kernel.hpp:284
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_splitkv_kernel.hpp:42
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_splitkv_kernel.hpp:47
std::conditional_t< kIsGroupMode, GroupModeKargs, BatchModeKargs > Kargs
Definition: fmha_fwd_splitkv_kernel.hpp:273
static constexpr bool kHasSink
Definition: fmha_fwd_splitkv_kernel.hpp:54
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead_q, ck_tile::index_t nhead_kv, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits)
Definition: fmha_fwd_splitkv_kernel.hpp:537
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_splitkv_kernel.hpp:46
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_splitkv_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_splitkv_kernel.hpp:34
remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_splitkv_kernel.hpp:40
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: fmha_fwd_splitkv_kernel.hpp:55
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_splitkv_kernel.hpp:58
static constexpr bool kDoFp8StaticQuant
Definition: fmha_fwd_splitkv_kernel.hpp:52
static CK_TILE_HOST std::string GetName()
Definition: fmha_fwd_splitkv_kernel.hpp:74
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_splitkv_kernel.hpp:38
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_fwd_splitkv_kernel.hpp:582
static constexpr bool kHasMask
Definition: fmha_fwd_splitkv_kernel.hpp:59
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_splitkv_kernel.hpp:555
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_splitkv_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_splitkv_kernel.hpp:35
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_splitkv_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_splitkv_kernel.hpp:37
remove_cvref_t< typename FmhaPipeline::OaccDataType > OaccDataType
Definition: fmha_fwd_splitkv_kernel.hpp:39
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type)
Definition: fmha_fwd_splitkv_kernel.hpp:416
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_splitkv_kernel.hpp:57
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_splitkv_kernel.hpp:28
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_splitkv_kernel.hpp:33
static constexpr bool kStoreLSE
Definition: fmha_fwd_splitkv_kernel.hpp:51
static constexpr bool kIsPagedKV
Definition: fmha_fwd_splitkv_kernel.hpp:53
static constexpr bool kIsGroupMode
Definition: fmha_fwd_splitkv_kernel.hpp:44
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_splitkv_kernel.hpp:599
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_splitkv_kernel.hpp:49
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_splitkv_kernel.hpp:25
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49