/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File
fmha_fwd_splitkv_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 
14 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
15 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
16 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
17 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
18 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
19 
20 namespace ck_tile {
21 
22 template <typename FmhaPipeline_, typename EpiloguePipeline_>
24 {
27  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
28  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
29 
30  static_assert(kBlockPerCu > 0);
31  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
32 
41 
43 
44  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
45  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
46  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
47  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
48  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
49  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
50  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
51  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
52  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
53  static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
54  static constexpr bool kMergeNumHeadGroupsSeqLenQ =
55  FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
58  static constexpr bool kHasMask = FmhaMask::IsMasking;
59 
60  static_assert(!kMergeNumHeadGroupsSeqLenQ ||
62  !kHasMask));
63 
64  // clang-format off
// t2s ("type to string"): compile-time map from an element type to the short
// tag used when composing the kernel name in GetName() below.
65  template <typename T> struct t2s;
66  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
67  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
68  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
69  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
70  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
71  // clang-format on
72 
// Compose a unique, human-readable kernel name from the compile-time
// configuration (tile shape, warp layout, padding, bias/mask/LSE/quant/paged-KV
// flags). The format is consumed by the code-generation tooling, so any change
// here must be mirrored there (see the note below).
73  CK_TILE_HOST static std::string GetName()
74  {
75  // sync with generate.py
76  // clang-format off
77  using bfs = typename FmhaPipeline::BlockFmhaShape;
78  using g0br = typename bfs::Gemm0BlockWarps;
79  using g1br = typename bfs::Gemm1BlockWarps;
80  using g0wt = typename bfs::Gemm0WarpTile;
81  using g1wt = typename bfs::Gemm1WarpTile;
82  #define _SS_ std::string
83  #define _TS_ std::to_string
// pn: padding tag, e.g. "psskddv" when all four dims are padded; empty when none are.
84  auto pn = [&] () {
85  std::string n;
86  if (kPadSeqLenQ) n += "s";
87  if (kPadSeqLenK) n += "sk";
88  if (kPadHeadDimQ) n += "d";
89  if (kPadHeadDimV) n += "dv";
90  return n.empty() ? n : std::string("p") + n; }();
91  return
92  _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
93  "_" + (kIsGroupMode ? "group" : "batch") + "_"
94  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
95  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
96  "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
97  "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
98  "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
99  "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
// "o<N>" (occupancy override) is emitted only when the user explicitly set
// kBlockPerCu in the problem description (-1 means "use the pipeline default").
100  (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
101  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
102  (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
103  (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
104  (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
105  #undef _SS_
106  #undef _TS_
107  // clang-format on
108  }
109 
// Zero-size placeholder used in the std::conditional_t chains of the kargs
// structs below. The integer parameter I makes each instantiation a distinct
// type, so the same "empty" base can appear several times in one inheritance
// list without violating the no-duplicate-direct-base rule.
110  template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a template
111  // arg
112  struct EmptyKargs
113  {
114  };
115 
116  // kargs use aggregate initialization, so no constructor will be provided;
117  // use inheritance to minimize karg size.
118  // users need to use the MakeKargs() function to create kargs.
119  struct CommonKargs
120  {
121  const void* q_ptr;
122  const void* k_ptr;
123  const void* v_ptr;
124  void* lse_acc_ptr;
125  void* o_acc_ptr;
126 
128 
133 
135  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
136  // if this param is larger than 1, indicate MQA/GQA case
139 
140  float scale_s;
141 
146 
152 
155  };
156 
158  {
159  LogitsSoftCapKargs() = default;
160 
161  void init_logits_soft_cap(float logits_soft_cap_)
162  {
163  if(0 < logits_soft_cap_)
164  {
165  logits_soft_cap = logits_soft_cap_;
167  }
168  else
169  {
170  logits_soft_cap = 0.f;
171  logits_soft_cap_rcp = 0.f;
172  }
173  }
174 
177  };
178 
180  {
181  const void* bias_ptr = nullptr;
184  };
185 
187  {
189  };
190 
// Kernel arguments for ALiBi positional bias. Mixed into the kargs struct only
// when BiasEnum == BlockAttentionBiasEnum::ALIBI.
191  struct AlibiKargs
192  {
193  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
194  const void* alibi_slope_ptr;
195  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
196  };
197 
198  struct MaskKargs
199  {
200  // ck_tile::index_t window_size_left, window_size_right;
203  };
204 
206  {
207  float scale_p;
208  };
209 
211  {
215  };
216 
218  {
219  bool is_gappy = false;
220  };
221 
223  {
225  };
226 
228  : CommonKargs,
229  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
230  BatchModeBiasKargs,
231  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
232  AlibiKargs,
233  EmptyKargs<0>>>,
234  std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
235  std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
236  std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>,
237  std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<3>>
238  {
240 
242  ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
243  // single kcache page-block
244  ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
245  // single vcache page-block
248  };
249 
251  : CommonKargs,
252  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
253  CommonBiasKargs,
254  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
255  AlibiKargs,
256  EmptyKargs<0>>>,
257  std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
258  std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
259  std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>,
260  std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<4>>
261  {
265 
266  ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
267  // for single kcache page-block
268  ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
269  // for single vcache page-block
270  };
271 
// Select the concrete kernel-argument type for the compiled mode; both
// variants derive from CommonKargs plus the feature mix-ins chosen above.
272  using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
273 
275  {
279  };
280 
281  template <bool Cond = !kIsGroupMode>
282  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
283  MakeKargs(const void* q_ptr,
284  const void* k_ptr,
285  const void* v_ptr,
286  const void* bias_ptr,
287  void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
288  final lse */
289  void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
290  o */
291  ck_tile::index_t batch,
292  ck_tile::index_t seqlen_q,
293  ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
294  const void* seqlen_k_ptr, // only used for (paged-) kvcache
295  ck_tile::index_t hdim_q,
296  ck_tile::index_t hdim_v,
297  ck_tile::index_t num_head_q,
298  ck_tile::index_t nhead_ratio_qk,
299  ck_tile::index_t num_splits,
300  const void* block_table_ptr,
301  ck_tile::index_t batch_stride_block_table,
302  ck_tile::index_t page_block_size,
303  const void* cache_batch_idx,
304  float scale_s,
305  float scale_p,
306  float logits_soft_cap,
307  ck_tile::index_t stride_q,
308  ck_tile::index_t stride_k,
309  ck_tile::index_t stride_v,
310  ck_tile::index_t stride_bias,
311  ck_tile::index_t stride_o_acc,
312  ck_tile::index_t nhead_stride_q,
313  ck_tile::index_t nhead_stride_k,
314  ck_tile::index_t nhead_stride_v,
315  ck_tile::index_t nhead_stride_bias,
316  ck_tile::index_t nhead_stride_lse_acc,
317  ck_tile::index_t nhead_stride_o_acc,
318  ck_tile::index_t batch_stride_q,
319  ck_tile::index_t batch_stride_k,
320  ck_tile::index_t batch_stride_v,
321  ck_tile::index_t batch_stride_bias,
322  ck_tile::index_t batch_stride_lse_acc,
323  ck_tile::index_t batch_stride_o_acc,
324  ck_tile::index_t split_stride_lse_acc,
325  ck_tile::index_t split_stride_o_acc,
326  ck_tile::index_t window_size_left,
327  ck_tile::index_t window_size_right,
328  ck_tile::index_t mask_type)
329  {
330  Kargs kargs{{q_ptr,
331  k_ptr,
332  v_ptr,
333  lse_acc_ptr,
334  o_acc_ptr,
335  batch,
336  seqlen_q,
337  seqlen_k,
338  hdim_q,
339  hdim_v,
340  num_head_q,
341  nhead_ratio_qk,
342  num_splits,
343 #if CK_TILE_FMHA_FWD_FAST_EXP2
344  static_cast<float>(scale_s * ck_tile::log2e_v<>),
345 #else
346  scale_s,
347 #endif
348  stride_q,
349  stride_k,
350  stride_v,
351  stride_o_acc,
352  nhead_stride_q,
353  nhead_stride_k,
354  nhead_stride_v,
355  nhead_stride_lse_acc,
356  nhead_stride_o_acc,
357  split_stride_lse_acc,
358  split_stride_o_acc}, // args for common karg
359  {}, // placeholder for bias
360  {}, // placeholder for mask
361  {}, // placeholder for fp8_static_quant args
362  {}, // placeholder for paged-block table or cache_batch_idx
363  {}, // placeholder for logits_soft_cap
364  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
365  batch_stride_q,
366  batch_stride_k,
367  batch_stride_v,
368  batch_stride_lse_acc,
369  batch_stride_o_acc};
370 
372  {
373  kargs.bias_ptr = bias_ptr;
374  kargs.stride_bias = stride_bias;
375  kargs.nhead_stride_bias = nhead_stride_bias;
376  kargs.batch_stride_bias = batch_stride_bias;
377  }
378  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
379  {
380  kargs.alibi_slope_ptr = bias_ptr;
381  kargs.alibi_slope_stride = stride_bias;
382  }
383  if constexpr(kHasMask)
384  {
385  kargs.window_size_left = window_size_left;
386  kargs.window_size_right = window_size_right;
387  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
388  }
389  if constexpr(kDoFp8StaticQuant)
390  {
391  kargs.scale_p = scale_p;
392  }
393  if constexpr(kIsPagedKV)
394  {
395  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
396  kargs.batch_stride_block_table = batch_stride_block_table;
397  kargs.page_block_size = page_block_size;
398  }
399  else
400  {
401  kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
402  }
403  if constexpr(kHasLogitsSoftCap)
404  {
405  kargs.init_logits_soft_cap(logits_soft_cap);
406  }
407 
408  return kargs;
409  }
410 
411  template <bool Cond = kIsGroupMode>
412  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
413  MakeKargs(const void* q_ptr,
414  const void* k_ptr,
415  const void* v_ptr,
416  const void* bias_ptr,
417  void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
418  final lse */
419  void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
420  o */
421  ck_tile::index_t batch,
422  const void* seqstart_q_ptr,
423  const void* seqstart_k_ptr,
424  const void* seqlen_k_ptr,
425  ck_tile::index_t hdim_q,
426  ck_tile::index_t hdim_v,
427  ck_tile::index_t num_head_q,
428  ck_tile::index_t nhead_ratio_qk,
429  ck_tile::index_t num_splits,
430  const void* block_table_ptr,
431  ck_tile::index_t batch_stride_block_table,
432  ck_tile::index_t page_block_size,
433  bool is_gappy,
434  float scale_s,
435  float scale_p,
436  float logits_soft_cap,
437  ck_tile::index_t stride_q,
438  ck_tile::index_t stride_k,
439  ck_tile::index_t stride_v,
440  ck_tile::index_t stride_bias,
441  ck_tile::index_t stride_o_acc,
442  ck_tile::index_t nhead_stride_q,
443  ck_tile::index_t nhead_stride_k,
444  ck_tile::index_t nhead_stride_v,
445  ck_tile::index_t nhead_stride_bias,
446  ck_tile::index_t nhead_stride_lse_acc,
447  ck_tile::index_t nhead_stride_o_acc,
448  ck_tile::index_t batch_stride_k, // only used for paged-kvcache
449  ck_tile::index_t batch_stride_v, // only used for paged-kvcache
450  ck_tile::index_t split_stride_lse_acc,
451  ck_tile::index_t split_stride_o_acc,
452  ck_tile::index_t window_size_left,
453  ck_tile::index_t window_size_right,
454  ck_tile::index_t mask_type)
455  {
456  Kargs kargs{{q_ptr,
457  k_ptr,
458  v_ptr,
459  lse_acc_ptr,
460  o_acc_ptr,
461  batch,
462  -1, // seqlen_q will be updated by another pointer
463  -1, // seqlen_k will be updated by another pointer
464  hdim_q,
465  hdim_v,
466  num_head_q,
467  nhead_ratio_qk,
468  num_splits,
469 #if CK_TILE_FMHA_FWD_FAST_EXP2
470  static_cast<float>(scale_s * ck_tile::log2e_v<>),
471 #else
472  scale_s,
473 #endif
474  stride_q,
475  stride_k,
476  stride_v,
477  stride_o_acc,
478  nhead_stride_q,
479  nhead_stride_k,
480  nhead_stride_v,
481  nhead_stride_lse_acc,
482  nhead_stride_o_acc,
483  split_stride_lse_acc,
484  split_stride_o_acc}, // args for common karg
485  {}, // placeholder for bias
486  {}, // placeholder for mask
487  {}, // placeholder for fp8_static_quant args
488  {}, // placeholder for paged-block table
489  {}, // placeholder for logits_soft_cap
490  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
491  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
492  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
493  batch_stride_k,
494  batch_stride_v};
495 
497  {
498  kargs.bias_ptr = bias_ptr;
499  kargs.stride_bias = stride_bias;
500  kargs.nhead_stride_bias = nhead_stride_bias;
501  }
502  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
503  {
504  kargs.alibi_slope_ptr = bias_ptr;
505  kargs.alibi_slope_stride = stride_bias;
506  }
507  if constexpr(kHasMask)
508  {
509  kargs.window_size_left = window_size_left;
510  kargs.window_size_right = window_size_right;
511  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
512  }
513  if constexpr(kDoFp8StaticQuant)
514  {
515  kargs.scale_p = scale_p;
516  }
517  if constexpr(kIsPagedKV)
518  {
519  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
520  kargs.batch_stride_block_table = batch_stride_block_table;
521  kargs.page_block_size = page_block_size;
522  kargs.is_gappy = is_gappy;
523  }
524  if constexpr(kHasLogitsSoftCap)
525  {
526  kargs.init_logits_soft_cap(logits_soft_cap);
527  }
528 
529  return kargs;
530  }
531 
532  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
533  ck_tile::index_t nhead_q,
534  ck_tile::index_t nhead_kv,
535  ck_tile::index_t max_seqlen_q,
536  ck_tile::index_t hdim_v,
537  ck_tile::index_t num_splits)
538  {
539  ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q;
540  ck_tile::index_t max_seqlen_q_ =
541  max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1);
542 
543  // TODO: this may need tuning
544  return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) *
545  ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
546  nhead_,
547  batch_size);
548  }
549 
// Decode the 3-D block index into (row-tile, col-tile, split, head, batch).
// Inverse of the packing performed by GridSize(): blockIdx.x encodes
// ((i_tile_m * num_tile_n1) + i_tile_n) * num_splits + i_split.
550  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
551  {
552  const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
553 
// divmod helper: returns (dividend / divisor, dividend % divisor).
554  const auto f = [](index_t dividend, index_t divisor) {
555  index_t quotient = dividend / divisor;
556  index_t modulus = dividend - quotient * divisor;
557  return ck_tile::make_tuple(quotient, modulus);
558  };
559 
560  const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits);
561  const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
562  const index_t i_nhead = blockIdx.y;
563  const index_t i_batch = blockIdx.z;
564 
565  if constexpr(kHasMask)
566  {
567  // assume that num_tile_n1 is always 1
// NOTE(review): with a mask the row-tile order is reversed so later (cheaper,
// mostly-masked) tiles are issued first — presumably for better wave
// load-balancing; confirm against the masking pipeline before relying on it.
// The reversal math is only valid under the num_tile_n1 == 1 assumption above.
568  return ck_tile::make_tuple(
569  (gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
570  }
571  else
572  {
573  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
574  }
575  }
576 
577  CK_TILE_HOST static dim3 BlockSize()
578  {
579  if(is_wave32())
580  {
581  return dim3(kBlockSize / 2);
582  }
583  else
584  {
585  return dim3(kBlockSize);
586  }
587  }
588 
590  {
591  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
592  }
593 
594  CK_TILE_DEVICE void operator()(Kargs kargs) const
595  {
596  // allocate LDS
597  __shared__ char smem_ptr[GetSmemSize()];
598 
599  // divide problem
600  const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
601 
602  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
603  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
604 
605  long_index_t batch_offset_q = 0;
606  long_index_t batch_offset_k = 0; // unused for paged-kvcache
607  long_index_t batch_offset_v = 0; // unused for paged-kvcache
608  long_index_t batch_offset_bias = 0;
609  long_index_t batch_offset_lse_acc = 0;
610  long_index_t batch_offset_o_acc = 0;
611  index_t kv_l2p_offset =
612  0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
613 
614  if constexpr(kIsGroupMode)
615  {
616  // get starting offset for each batch
617  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
618  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
619 
620  batch_offset_q = query_start * kargs.stride_q;
621  batch_offset_k = key_start * kargs.stride_k;
622  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
623  {
624  batch_offset_v = key_start * kargs.stride_v;
625  }
626  else
627  {
628  batch_offset_v = key_start;
629  }
631  {
632  batch_offset_bias = query_start * kargs.stride_bias;
633  }
634 
635  batch_offset_lse_acc = query_start;
636  batch_offset_o_acc = query_start * kargs.stride_o_acc;
637 
638  // get real # queries & # keys under group mode
639  kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
640 
641  // # of required blocks is different in each groups, terminate unnecessary blocks
642  // earlier
643  if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0)
644  {
645  return;
646  }
647 
648  if(kargs.seqlen_k_ptr != nullptr)
649  {
650  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
651  }
652  else
653  {
654  kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
655  }
656 
657  if constexpr(kIsPagedKV)
658  {
659  if(kargs.is_gappy)
660  {
661  // seqstart_k_ptr has different meaning in this case
662  kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
663  }
664  }
665  }
666  else
667  {
668  const index_t i_cache_batch = [&, i_batch_ = i_batch] {
669  if constexpr(kIsPagedKV)
670  {
671  return i_batch_;
672  }
673  else
674  {
675  return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
676  : i_batch_);
677  }
678  }();
679 
680  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
681  batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
682  batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
683  batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
684  batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
685 
687  {
688  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
689  }
690 
691  if(kargs.seqlen_k_ptr != nullptr)
692  {
693  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
694  }
695  }
696 
697  // for simplicity, batch stride we just modify the pointer
698  const index_t i_nhead_k =
699  (kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
700 
701  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
702  static_cast<long_index_t>(i_nhead) *
703  (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
704  kargs.nhead_stride_q +
705  batch_offset_q;
706  const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
707  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
708  batch_offset_k;
709  const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
710  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
711  batch_offset_v;
712 
713  ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
714  static_cast<long_index_t>(i_nhead) *
715  (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
716  kargs.nhead_stride_o_acc +
717  batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
718 
719  // Q/K/V DRAM and DRAM window
720  const auto q_dram = [&] {
721  const auto q_dram_naive = [&] {
722  if constexpr(kMergeNumHeadGroupsSeqLenQ)
723  {
724  // reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
725  // hdim_q)
726  const auto view = make_naive_tensor_view<address_space_enum::global>(
727  q_ptr,
728  make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
729  make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1),
731  number<1>{});
732 
733  return transform_tensor_view(
734  view,
735  make_tuple(
736  make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
737  make_pass_through_transform(kargs.hdim_q)),
740  }
741  else
742  {
743  return make_naive_tensor_view<address_space_enum::global>(
744  q_ptr,
745  make_tuple(kargs.seqlen_q, kargs.hdim_q),
746  make_tuple(kargs.stride_q, 1),
748  number<1>{});
749  }
750  }();
751 
752  if constexpr(FmhaPipeline::kQLoadOnce)
753  {
754  return pad_tensor_view(
755  q_dram_naive,
758  }
759  else
760  {
761  return pad_tensor_view(
762  q_dram_naive,
765  }
766  }();
767 
768  const auto make_k_dram = [&](const KDataType* data, index_t height) {
769  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
770  data, // will update this pointer if using paged-kvcache
771  make_tuple(height, kargs.hdim_q),
772  make_tuple(kargs.stride_k, 1),
774  number<1>{});
775 
776  return pad_tensor_view(
777  k_dram_naive,
780  };
781  const auto k_dram = [&]() {
782  if constexpr(kIsPagedKV)
783  {
784  return make_k_dram(nullptr, kargs.page_block_size);
785  }
786  else
787  {
788  return make_k_dram(k_ptr, kargs.seqlen_k);
789  }
790  }();
791 
792  const auto make_v_dram = [&](const VDataType* data, index_t length) {
793  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
794  {
795  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
796  data, // will update this pointer if using paged-kvcache
797  make_tuple(length, kargs.hdim_v),
798  make_tuple(kargs.stride_v, 1),
800  number<1>{});
801 
802  const auto v_dram_transposed =
803  transform_tensor_view(v_dram_naive,
808 
809  return pad_tensor_view(
810  v_dram_transposed,
813  }
814  else
815  {
816  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
817  data, // will update this pointer if using paged-kvcache
818  make_tuple(kargs.hdim_v, length),
819  make_tuple(kargs.stride_v, 1),
821  number<1>{});
822 
823  return pad_tensor_view(
824  v_dram_naive,
827  }
828  };
829  const auto v_dram = [&]() {
830  if constexpr(kIsPagedKV)
831  {
832  return make_v_dram(nullptr, kargs.page_block_size);
833  }
834  else
835  {
836  return make_v_dram(v_ptr, kargs.seqlen_k);
837  }
838  }();
839 
840  auto k_page_block_navigator = [&, i_batch_ = i_batch]() {
841  if constexpr(kIsPagedKV)
842  {
843  const auto* block_indices =
844  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
845  i_batch_ * kargs.batch_stride_block_table;
846  const index_t num_blocks =
847  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
848 
849  const long_index_t fixed_offset =
850  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k;
851 
852  return make_page_block_navigator<const KDataType, 0>(
853  kargs.k_ptr,
854  kargs.batch_stride_k, // kcache page-block stride/size
855  fixed_offset,
856  block_indices,
857  num_blocks,
858  kargs.page_block_size,
859  k_dram,
860  make_k_dram(nullptr,
861  (kv_l2p_offset + kargs.seqlen_k) -
862  (num_blocks - 1) * kargs.page_block_size));
863  }
864  else
865  {
866  return make_page_block_navigator(k_dram);
867  }
868  }();
869 
870  auto v_page_block_navigator = [&, i_batch_ = i_batch]() {
871  if constexpr(kIsPagedKV)
872  {
873  const auto* block_indices =
874  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
875  i_batch_ * kargs.batch_stride_block_table;
876  const index_t num_blocks =
877  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
878 
879  const long_index_t fixed_offset =
880  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v;
881 
882  return make_page_block_navigator<const VDataType, 1>(
883  kargs.v_ptr,
884  kargs.batch_stride_v, // vcache page-block stride/size
885  fixed_offset,
886  block_indices,
887  num_blocks,
888  kargs.page_block_size,
889  v_dram,
890  make_v_dram(nullptr,
891  (kv_l2p_offset + kargs.seqlen_k) -
892  (num_blocks - 1) * kargs.page_block_size));
893  }
894  else
895  {
896  return make_page_block_navigator(v_dram);
897  }
898  }();
899 
900  auto q_dram_window = make_tile_window(
901  q_dram,
902  [&]() {
903  if constexpr(FmhaPipeline::kQLoadOnce)
906  else
908  }(),
909  {i_m0, 0});
910 
911  auto k_dram_window_lengths =
913  auto v_dram_window_lengths =
915 
918  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
919  constexpr auto bias_dram_window_lengths =
922  {
923  const BiasDataType* bias_ptr =
924  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
925  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
926  batch_offset_bias;
927 
928  const auto bias_dram = [&]() {
929  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
930  bias_ptr,
931  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
932  make_tuple(kargs.stride_bias, 1),
934  number<1>{});
935 
936  return pad_tensor_view(
937  bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
938  }();
939 
940  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
941  }
942  else
943  {
944  return make_null_tile_window(bias_dram_window_lengths);
945  }
946  }();
947 
948  // lse acc
949  auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() {
950  constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
951  LSEDataType* lse_acc_ptr = reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
952  static_cast<long_index_t>(i_nhead_) *
953  (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
954  kargs.nhead_stride_lse_acc +
955  batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;
956 
957  const auto lse_acc_dram = [&] {
958  const auto lse_acc_dram_naive = [&] {
959  if constexpr(kMergeNumHeadGroupsSeqLenQ)
960  {
961  // reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q)
962  const auto view = make_naive_tensor_view<address_space_enum::global>(
963  lse_acc_ptr,
964  make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q),
965  make_tuple(kargs.nhead_stride_lse_acc, 1),
966  number<1>{},
967  number<1>{});
968 
969  return transform_tensor_view(view,
971  kargs.nhead_ratio_qk, kargs.seqlen_q))),
974  }
975  else
976  {
977  return make_naive_tensor_view<address_space_enum::global>(
978  lse_acc_ptr,
979  make_tuple(kargs.seqlen_q),
980  make_tuple(1),
981  number<1>{},
982  number<1>{});
983  }
984  }();
985  return pad_tensor_view(
986  lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{});
987  }();
988 
989  return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0});
990  }();
991 
992  FmhaMask mask = [&]() {
993  if constexpr(kHasMask)
994  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
995  kargs.window_size_left,
996  kargs.window_size_right,
997  kargs.seqlen_q,
998  kargs.seqlen_k,
1000  else
1001  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1002  }();
1003 
1004  // WA i_batch capture structure binding before c++20
1005  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1006  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1007  {
1008  // data loading, shared by entire wg
1009  // TODO: how to use s_read?
1010  SaccDataType slope =
1011  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1012  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1013 #if CK_TILE_FMHA_FWD_FAST_EXP2
1014  slope *= ck_tile::log2e_v<>;
1015 #endif
1016  if constexpr(kHasMask)
1017  {
1018  return make_alibi_from_lr_mask<SaccDataType, true, 32>(slope,
1019  kargs.window_size_left,
1020  kargs.window_size_right,
1021  kargs.seqlen_q,
1022  kargs.seqlen_k,
1023  kargs.mask_type);
1024  }
1025  else
1026  {
1027  return Alibi<SaccDataType, true, 32>{
1028  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1029  }
1030  }
1031  else
1032  {
1033  return EmptyPositionEncoding<SaccDataType>{};
1034  }
1035  }();
1036 
1037  AttentionVariant variant;
1038  const auto variant_params = [&] {
1039  if constexpr(kHasLogitsSoftCap)
1040  {
1042  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1043  }
1044  else
1045  {
1046  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1047  }
1048  }();
1049 
1050  BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};
1051 
1052  auto o_acc_tile = [&, i_split_ = i_split]() {
1053  if constexpr(kDoFp8StaticQuant)
1054  {
1055  return FmhaPipeline{}(q_dram_window,
1056  identity{}, // q_element_func
1057  k_dram_window_lengths,
1058  k_page_block_navigator,
1059  identity{}, // k_element_func
1060  v_dram_window_lengths,
1061  v_page_block_navigator,
1062  identity{}, // v_element_func
1063  bias_dram_window,
1064  identity{}, // bias_element_func
1065  lse_acc_dram_window,
1066  identity{}, // lse_element_func
1067  identity{}, // s_acc_element_func
1068  scales{kargs.scale_p}, // p_compute_element_func
1069  identity{}, // o_acc_element_func
1070  kargs.num_splits,
1071  i_split_,
1072  mask,
1073  position_encoding,
1074  kargs.scale_s,
1075  variant,
1076  variant_params,
1077  block_indices,
1078  kv_l2p_offset,
1079  smem_ptr);
1080  }
1081  else
1082  {
1083  return FmhaPipeline{}(q_dram_window,
1084  k_dram_window_lengths,
1085  k_page_block_navigator,
1086  v_dram_window_lengths,
1087  v_page_block_navigator,
1088  bias_dram_window,
1089  lse_acc_dram_window,
1090  kargs.num_splits,
1091  i_split_,
1092  mask,
1093  position_encoding,
1094  kargs.scale_s,
1095  variant,
1096  variant_params,
1097  block_indices,
1098  kv_l2p_offset,
1099  smem_ptr);
1100  }
1101  }();
1102 
1103  // Oacc DRAM and Oacc DRAM window
1104  auto o_acc_dram = [&] {
1105  const auto o_acc_dram_naive = [&] {
1106  if constexpr(kMergeNumHeadGroupsSeqLenQ)
1107  {
1108  // reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
1109  // hdim_v)
1110  const auto view = make_naive_tensor_view<address_space_enum::global>(
1111  o_acc_ptr,
1112  make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),
1113  make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1),
1114  number<FmhaPipeline::kAlignmentOacc>{},
1115  number<1>{});
1116 
1117  return transform_tensor_view(
1118  view,
1119  make_tuple(
1120  make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
1121  make_pass_through_transform(kargs.hdim_v)),
1122  make_tuple(sequence<0, 1>{}, sequence<2>{}),
1123  make_tuple(sequence<0>{}, sequence<1>{}));
1124  }
1125  else
1126  {
1127  return make_naive_tensor_view<address_space_enum::global>(
1128  o_acc_ptr,
1129  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1130  make_tuple(kargs.stride_o_acc, 1),
1131  number<FmhaPipeline::kAlignmentOacc>{},
1132  number<1>{});
1133  }
1134  }();
1135 
1136  return pad_tensor_view(
1137  o_acc_dram_naive,
1138  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
1139  sequence<kPadSeqLenQ, kPadHeadDimV>{});
1140  }();
1141 
1142  auto o_acc_dram_window =
1143  make_tile_window(o_acc_dram,
1144  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
1145  {i_m0, i_n1});
1146 
1147  EpiloguePipeline{}(o_acc_dram_window, o_acc_tile, nullptr);
1148  }
1149 };
1150 
1151 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
_Float16 fp16_t
Definition: half.hpp:110
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition: page_block_navigator.hpp:333
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:511
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
__host__ __device__ scales(Scale) -> scales< Scale >
FIXME: create macro to replace 'host device' and nothing more.
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_attention_bias_enum.hpp:19
Definition: fmha_fwd_splitkv_kernel.hpp:192
const void * alibi_slope_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:194
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_splitkv_kernel.hpp:195
Definition: fmha_fwd_splitkv_kernel.hpp:187
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:188
Definition: fmha_fwd_splitkv_kernel.hpp:238
ck_tile::index_t batch_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:246
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:244
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:241
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:242
ck_tile::index_t batch_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:247
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:239
Definition: fmha_fwd_splitkv_kernel.hpp:275
ck_tile::index_t batch_idx
Definition: fmha_fwd_splitkv_kernel.hpp:276
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_splitkv_kernel.hpp:278
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_splitkv_kernel.hpp:277
Definition: fmha_fwd_splitkv_kernel.hpp:223
const int32_t * cache_batch_idx
Definition: fmha_fwd_splitkv_kernel.hpp:224
Definition: fmha_fwd_splitkv_kernel.hpp:180
const void * bias_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:181
ck_tile::index_t stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:182
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:183
Definition: fmha_fwd_splitkv_kernel.hpp:120
ck_tile::index_t split_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:154
ck_tile::index_t nhead_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:151
const void * k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:122
ck_tile::index_t num_splits
Definition: fmha_fwd_splitkv_kernel.hpp:138
void * lse_acc_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:124
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:147
void * o_acc_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:125
ck_tile::index_t nhead_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:150
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:149
ck_tile::index_t hdim_q
Definition: fmha_fwd_splitkv_kernel.hpp:131
const void * v_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:123
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:148
ck_tile::index_t split_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:153
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_splitkv_kernel.hpp:137
ck_tile::index_t stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:142
ck_tile::index_t seqlen_k
Definition: fmha_fwd_splitkv_kernel.hpp:130
ck_tile::index_t batch
Definition: fmha_fwd_splitkv_kernel.hpp:127
const void * q_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:121
ck_tile::index_t stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:144
ck_tile::index_t stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:143
ck_tile::index_t num_head_q
Definition: fmha_fwd_splitkv_kernel.hpp:134
ck_tile::index_t seqlen_q
Definition: fmha_fwd_splitkv_kernel.hpp:129
ck_tile::index_t stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:145
float scale_s
Definition: fmha_fwd_splitkv_kernel.hpp:140
ck_tile::index_t hdim_v
Definition: fmha_fwd_splitkv_kernel.hpp:132
Definition: fmha_fwd_splitkv_kernel.hpp:211
ck_tile::index_t page_block_size
Definition: fmha_fwd_splitkv_kernel.hpp:214
ck_tile::index_t batch_stride_block_table
Definition: fmha_fwd_splitkv_kernel.hpp:213
const int32_t * block_table_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:212
Definition: fmha_fwd_splitkv_kernel.hpp:113
Definition: fmha_fwd_splitkv_kernel.hpp:206
float scale_p
Definition: fmha_fwd_splitkv_kernel.hpp:207
Definition: fmha_fwd_splitkv_kernel.hpp:261
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:264
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:266
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:262
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:263
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:268
Definition: fmha_fwd_splitkv_kernel.hpp:218
bool is_gappy
Definition: fmha_fwd_splitkv_kernel.hpp:219
Definition: fmha_fwd_splitkv_kernel.hpp:158
float logits_soft_cap_rcp
Definition: fmha_fwd_splitkv_kernel.hpp:176
float logits_soft_cap
Definition: fmha_fwd_splitkv_kernel.hpp:175
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_splitkv_kernel.hpp:161
Definition: fmha_fwd_splitkv_kernel.hpp:199
ck_tile::index_t window_size_right
Definition: fmha_fwd_splitkv_kernel.hpp:201
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_splitkv_kernel.hpp:202
ck_tile::index_t window_size_left
Definition: fmha_fwd_splitkv_kernel.hpp:201
Definition: fmha_fwd_splitkv_kernel.hpp:65
Definition: fmha_fwd_splitkv_kernel.hpp:24
static constexpr auto BiasEnum
Definition: fmha_fwd_splitkv_kernel.hpp:50
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_splitkv_kernel.hpp:27
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_splitkv_kernel.hpp:36
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_splitkv_kernel.hpp:589
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_splitkv_kernel.hpp:26
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_splitkv_kernel.hpp:42
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
Definition: fmha_fwd_splitkv_kernel.hpp:283
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_splitkv_kernel.hpp:47
std::conditional_t< kIsGroupMode, GroupModeKargs, BatchModeKargs > Kargs
Definition: fmha_fwd_splitkv_kernel.hpp:272
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead_q, ck_tile::index_t nhead_kv, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits)
Definition: fmha_fwd_splitkv_kernel.hpp:532
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_splitkv_kernel.hpp:46
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_splitkv_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_splitkv_kernel.hpp:34
remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_splitkv_kernel.hpp:40
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: fmha_fwd_splitkv_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_splitkv_kernel.hpp:57
static constexpr bool kDoFp8StaticQuant
Definition: fmha_fwd_splitkv_kernel.hpp:52
static CK_TILE_HOST std::string GetName()
Definition: fmha_fwd_splitkv_kernel.hpp:73
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_splitkv_kernel.hpp:38
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_fwd_splitkv_kernel.hpp:577
static constexpr bool kHasMask
Definition: fmha_fwd_splitkv_kernel.hpp:58
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_splitkv_kernel.hpp:550
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_splitkv_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_splitkv_kernel.hpp:35
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_splitkv_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_splitkv_kernel.hpp:37
remove_cvref_t< typename FmhaPipeline::OaccDataType > OaccDataType
Definition: fmha_fwd_splitkv_kernel.hpp:39
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_splitkv_kernel.hpp:56
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_splitkv_kernel.hpp:28
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
Definition: fmha_fwd_splitkv_kernel.hpp:413
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_splitkv_kernel.hpp:33
static constexpr bool kStoreLSE
Definition: fmha_fwd_splitkv_kernel.hpp:51
static constexpr bool kIsPagedKV
Definition: fmha_fwd_splitkv_kernel.hpp:53
static constexpr bool kIsGroupMode
Definition: fmha_fwd_splitkv_kernel.hpp:44
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_splitkv_kernel.hpp:594
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_splitkv_kernel.hpp:49
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_splitkv_kernel.hpp:25
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49