// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

#include <string>
#include <type_traits>

// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
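//
// Split-KV: the seqlen_k dimension is additionally partitioned into num_splits
// chunks; each split computes a partial O and a partial LSE over its chunk, and
// a separate combine kernel merges the splits afterwards. A sketch of that merge
// (the standard log-sum-exp reduction; not part of this file):
//
//   lse = log(sum_i exp(lse_acc_i))              // reduce over splits i
//   O   = sum_i exp(lse_acc_i - lse) * O_acc_i
//
// which is why this kernel writes per-split lse_acc / o_acc workspaces.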

namespace ck_tile {

template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVKernel
{
    using FmhaPipeline     = ck_tile::remove_cvref_t<FmhaPipeline_>;
    using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
    static constexpr ck_tile::index_t kBlockSize  = FmhaPipeline::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
    static_assert(kBlockPerCu > 0);
    static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

    using QDataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
    using KDataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
    using VDataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
    using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
    using LSEDataType  = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
    using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
    using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
    using ODataType    = remove_cvref_t<typename FmhaPipeline::ODataType>;

    using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;

    static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
    static constexpr bool kPadSeqLenQ  = FmhaPipeline::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK  = FmhaPipeline::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
    static constexpr auto BiasEnum     = FmhaPipeline::BiasEnum;
    static constexpr bool kStoreLSE    = FmhaPipeline::kStoreLSE;
    static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
    static constexpr bool kIsPagedKV        = FmhaPipeline::Problem::kIsPagedKV;
    static constexpr bool kMergeNumHeadGroupsSeqLenQ =
        FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;

    using FmhaMask                 = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
    static constexpr bool kHasMask = FmhaMask::IsMasking;

    static_assert(!kMergeNumHeadGroupsSeqLenQ ||
                  (kPadSeqLenQ && BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kHasMask));
    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on

    __host__ static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        using bfs = typename FmhaPipeline::BlockFmhaShape;
        using g0br = typename bfs::Gemm0BlockWarps;
        using g1br = typename bfs::Gemm1BlockWarps;
        using g0wt = typename bfs::Gemm0WarpTile;
        using g1wt = typename bfs::Gemm1WarpTile;
        #define _SS_ std::string
        #define _TS_ std::to_string
        auto pn = [&] () {
            std::string n;
            if (kPadSeqLenQ) n += "s";
            if (kPadSeqLenK) n += "sk";
            if (kPadHeadDimQ) n += "d";
            if (kPadHeadDimV) n += "dv";
            return n.empty() ? n : std::string("p") + n; }();
        return
            _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
            "_" + (kIsGroupMode ? "group" : "batch") + "_"
            "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
                _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
            "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
            "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
            "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
            "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
            "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
            (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
            (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" );
        #undef _SS_
        #undef _TS_
        // clang-format on
    }
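
    // For illustration, with hypothetical template parameters (hdim 128, fp16,
    // batch mode, 64x128x32x128x32x128 tiles, 4x1x1 warp layouts, 32x32x16 warp
    // tiles, a pipeline named "qr", row-major V, seqlen_k padding, LSE storing),
    // GetName() would produce a string of the form:
    //
    //   fmha_fwd_splitkv_d128_fp16_batch_b64x128x32x128x32x128_r4x1x1_r4x1x1_
    //       w32x32x16_w32x32x16_qr_vr_psk_lse
    //
    // (wrapped here for readability; the actual string is a single token).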

    // to avoid the duplicated-base-class problem, introduce a template arg
    template <ck_tile::index_t I>
    struct EmptyKargs
    {
    };

    // kargs use aggregate initialization, so no constructor is provided.
    // use inheritance to minimize karg size.
    // users need to use the MakeKargs() function to create kargs.
    struct CommonKargs
    {
        const void* q_ptr;
        const void* k_ptr;
        const void* v_ptr;
        void* lse_acc_ptr;
        void* o_acc_ptr;

        ck_tile::index_t batch;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t seqlen_k;
        ck_tile::index_t hdim_q;
        ck_tile::index_t hdim_v;

        ck_tile::index_t num_head_q;
        // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k.
        // if this param is larger than 1, it indicates a MQA/GQA case
        ck_tile::index_t nhead_ratio_qk;
        ck_tile::index_t num_splits;

        float scale_s;

        ck_tile::index_t stride_q;
        ck_tile::index_t stride_k;
        ck_tile::index_t stride_v;
        ck_tile::index_t stride_o_acc;

        ck_tile::index_t nhead_stride_q;
        ck_tile::index_t nhead_stride_k;
        ck_tile::index_t nhead_stride_v;
        ck_tile::index_t nhead_stride_lse_acc;
        ck_tile::index_t nhead_stride_o_acc;

        ck_tile::index_t split_stride_lse_acc;
        ck_tile::index_t split_stride_o_acc;
    };
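
    // Example (illustrative numbers): with num_head_q == 32 query heads sharing
    // 8 KV heads (GQA), nhead_ratio_qk == 32 / 8 == 4, and query head i reads
    // KV head i / 4 (see the i_nhead_k computation in operator() below).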

    struct CommonBiasKargs
    {
        const void* bias_ptr               = nullptr;
        ck_tile::index_t stride_bias       = 0;
        ck_tile::index_t nhead_stride_bias = 0;
    };

    struct BatchModeBiasKargs : CommonBiasKargs
    {
        ck_tile::index_t batch_stride_bias = 0;
    };

    struct AlibiKargs
    {
        // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
        const void* alibi_slope_ptr;
        ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batches sharing one slope
    };

    struct MaskKargs
    {
        ck_tile::index_t window_size_left, window_size_right;
        ck_tile::GenericAttentionMaskEnum mask_type;
    };

    struct Fp8StaticQuantKargs
    {
        float scale_p;
    };

    struct CommonPageBlockTableKargs
    {
        const int32_t* block_table_ptr;
        ck_tile::index_t batch_stride_block_table;
        ck_tile::index_t page_block_size;
    };

    struct GroupModePageBlockTableKargs : CommonPageBlockTableKargs
    {
        bool is_gappy = false;
    };

    struct CacheBatchIdxKargs
    {
        const int32_t* cache_batch_idx;
    };

    struct BatchModeKargs
        : CommonKargs,
          std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
                             BatchModeBiasKargs,
                             std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
                                                AlibiKargs,
                                                EmptyKargs<0>>>,
          std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
          std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>
    {
        const int32_t* seqlen_k_ptr;

        ck_tile::index_t batch_stride_q;
        ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
                                         // a single kcache page-block
        ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
                                         // a single vcache page-block
        ck_tile::index_t batch_stride_lse_acc;
        ck_tile::index_t batch_stride_o_acc;
    };

    struct GroupModeKargs
        : CommonKargs,
          std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
                             CommonBiasKargs,
                             std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
                                                AlibiKargs,
                                                EmptyKargs<0>>>,
          std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
          std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>
    {
        const int32_t* seqstart_q_ptr;
        const int32_t* seqstart_k_ptr;
        const int32_t* seqlen_k_ptr;

        ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
                                         // for a single kcache page-block
        ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
                                         // for a single vcache page-block
    };

    using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;

    template <bool Cond = !kIsGroupMode>
    __host__ static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
                                    final lse */
              void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
                                  o */
              ck_tile::index_t batch,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
              const void* seqlen_k_ptr,  // only used for (paged-) kvcache
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              ck_tile::index_t num_splits,
              const void* block_table_ptr,
              ck_tile::index_t batch_stride_block_table,
              ck_tile::index_t page_block_size,
              const void* cache_batch_idx,
              float scale_s,
              float scale_p,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_o_acc,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_lse_acc,
              ck_tile::index_t nhead_stride_o_acc,
              ck_tile::index_t batch_stride_q,
              ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v,
              ck_tile::index_t batch_stride_bias,
              ck_tile::index_t batch_stride_lse_acc,
              ck_tile::index_t batch_stride_o_acc,
              ck_tile::index_t split_stride_lse_acc,
              ck_tile::index_t split_stride_o_acc,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     v_ptr,
                     lse_acc_ptr,
                     o_acc_ptr,
                     batch,
                     seqlen_q,
                     seqlen_k,
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
                     num_splits,
#if CK_TILE_FMHA_FWD_FAST_EXP2
                     static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
                     scale_s,
#endif
                     stride_q,
                     stride_k,
                     stride_v,
                     stride_o_acc,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_v,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {},                   // placeholder for bias
                    {},                   // placeholder for mask
                    {},                   // placeholder for fp8_static_quant args
                    {},                   // placeholder for paged-block table or cache_batch_idx
                    reinterpret_cast<const int32_t*>(seqlen_k_ptr),
                    batch_stride_q,
                    batch_stride_k,
                    batch_stride_v,
                    batch_stride_lse_acc,
                    batch_stride_o_acc};

        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            kargs.bias_ptr          = bias_ptr;
            kargs.stride_bias       = stride_bias;
            kargs.nhead_stride_bias = nhead_stride_bias;
            kargs.batch_stride_bias = batch_stride_bias;
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            kargs.alibi_slope_ptr    = bias_ptr;
            kargs.alibi_slope_stride = stride_bias;
        }
        if constexpr(kHasMask)
        {
            kargs.window_size_left  = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type         = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_p = scale_p;
        }
        if constexpr(kIsPagedKV)
        {
            kargs.block_table_ptr          = reinterpret_cast<const int32_t*>(block_table_ptr);
            kargs.batch_stride_block_table = batch_stride_block_table;
            kargs.page_block_size          = page_block_size;
        }
        else
        {
            kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
        }

        return kargs;
    }
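
    // Note: callers typically pass scale_s = 1 / sqrt(hdim_q) (the usual softmax
    // scale). When CK_TILE_FMHA_FWD_FAST_EXP2 is enabled, MakeKargs folds log2(e)
    // into scale_s above so the pipeline can use exp2 instead of exp.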

    template <bool Cond = kIsGroupMode>
    __host__ static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
                                    final lse */
              void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
                                  o */
              ck_tile::index_t batch,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              const void* seqlen_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              ck_tile::index_t num_splits,
              const void* block_table_ptr,
              ck_tile::index_t batch_stride_block_table,
              ck_tile::index_t page_block_size,
              bool is_gappy,
              float scale_s,
              float scale_p,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_o_acc,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_lse_acc,
              ck_tile::index_t nhead_stride_o_acc,
              ck_tile::index_t batch_stride_k, // only used for paged-kvcache
              ck_tile::index_t batch_stride_v, // only used for paged-kvcache
              ck_tile::index_t split_stride_lse_acc,
              ck_tile::index_t split_stride_o_acc,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     v_ptr,
                     lse_acc_ptr,
                     o_acc_ptr,
                     batch,
                     -1, // seqlen_q will be updated by another pointer
                     -1, // seqlen_k will be updated by another pointer
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
                     num_splits,
#if CK_TILE_FMHA_FWD_FAST_EXP2
                     static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
                     scale_s,
#endif
                     stride_q,
                     stride_k,
                     stride_v,
                     stride_o_acc,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_v,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {},                   // placeholder for bias
                    {},                   // placeholder for mask
                    {},                   // placeholder for fp8_static_quant args
                    {},                   // placeholder for paged-block table
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                    reinterpret_cast<const int32_t*>(seqlen_k_ptr),
                    batch_stride_k,
                    batch_stride_v};

        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            kargs.bias_ptr          = bias_ptr;
            kargs.stride_bias       = stride_bias;
            kargs.nhead_stride_bias = nhead_stride_bias;
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            kargs.alibi_slope_ptr    = bias_ptr;
            kargs.alibi_slope_stride = stride_bias;
        }
        if constexpr(kHasMask)
        {
            kargs.window_size_left  = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type         = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_p = scale_p;
        }
        if constexpr(kIsPagedKV)
        {
            kargs.block_table_ptr          = reinterpret_cast<const int32_t*>(block_table_ptr);
            kargs.batch_stride_block_table = batch_stride_block_table;
            kargs.page_block_size          = page_block_size;
            kargs.is_gappy                 = is_gappy;
        }

        return kargs;
    }

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead_q,
                                                ck_tile::index_t nhead_kv,
                                                ck_tile::index_t max_seqlen_q,
                                                ck_tile::index_t hdim_v,
                                                ck_tile::index_t num_splits)
    {
        ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q;
        ck_tile::index_t max_seqlen_q_ =
            max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1);

        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) *
                        ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
                    nhead_,
                    batch_size);
    }
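
    // Worked example (hypothetical sizes): batch_size = 2, nhead_q = 8,
    // max_seqlen_q = 1000, hdim_v = 128, num_splits = 4, kM0 = 64, kN1 = 128,
    // kMergeNumHeadGroupsSeqLenQ disabled:
    //   grid.x = ceil(1000 / 64) * ceil(128 / 128) * 4 = 16 * 1 * 4 = 64
    //   grid.y = 8, grid.z = 2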

    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
    {
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck_tile::make_tuple(quotient, modulus);
        };

        const auto [mn, i_split]        = f(blockIdx.x, kargs.num_splits);
        const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
        const index_t i_nhead           = blockIdx.y;
        const index_t i_batch           = blockIdx.z;

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
    }
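
    // GetTileIndex() inverts the grid.x packing used by GridSize(), i.e.
    //   blockIdx.x == (i_tile_m * num_tile_n1 + i_tile_n) * num_splits + i_split
    // For instance, with num_tile_n1 = 1 and num_splits = 4, blockIdx.x = 9
    // decodes to i_split = 1 and (i_tile_m, i_tile_n) = (2, 0).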

    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        // divide problem
        const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
        const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

        long_index_t batch_offset_q       = 0;
        long_index_t batch_offset_k       = 0; // unused for paged-kvcache
        long_index_t batch_offset_v       = 0; // unused for paged-kvcache
        long_index_t batch_offset_bias    = 0;
        long_index_t batch_offset_lse_acc = 0;
        long_index_t batch_offset_o_acc   = 0;
        index_t kv_l2p_offset = 0; // logical-to-physical offset of the seqlen_k coordinate.
                                   // only used for paged-kvcache

        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            const long_index_t key_start   = kargs.seqstart_k_ptr[i_batch];

            batch_offset_q = query_start * kargs.stride_q;
            batch_offset_k = key_start * kargs.stride_k;
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                batch_offset_v = key_start * kargs.stride_v;
            }
            else
            {
                batch_offset_v = key_start;
            }
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias = query_start * kargs.stride_bias + key_start;
            }

            batch_offset_lse_acc = query_start;
            batch_offset_o_acc   = query_start * kargs.stride_o_acc;

            // get real # queries & # keys under group mode
            kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];

            // # of required blocks is different in each group, terminate unnecessary blocks
            // earlier
            if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0)
            {
                return;
            }

            if(kargs.seqlen_k_ptr != nullptr)
            {
                kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
            }
            else
            {
                kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
            }

            if constexpr(kIsPagedKV)
            {
                if(kargs.is_gappy)
                {
                    // seqstart_k_ptr has a different meaning in this case
                    kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
                }
            }
        }
        else
        {
            const index_t i_cache_batch = [&, i_batch_ = i_batch] {
                if constexpr(kIsPagedKV)
                {
                    return i_batch_;
                }
                else
                {
                    return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
                                                             : i_batch_);
                }
            }();

            batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
            batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
            batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
            batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
            batch_offset_o_acc   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;

            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
            }

            if(kargs.seqlen_k_ptr != nullptr)
            {
                kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
            }
        }

        // for simplicity, we apply the batch stride by just offsetting the base pointers
        const index_t i_nhead_k =
            (kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);

        const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
                                 static_cast<long_index_t>(i_nhead) *
                                     (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
                                     kargs.nhead_stride_q +
                                 batch_offset_q;
        const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
                                 static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
                                 batch_offset_k;
        const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
                                 static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
                                 batch_offset_v;

        ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
                               static_cast<long_index_t>(i_nhead) *
                                   (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
                                   kargs.nhead_stride_o_acc +
                               batch_offset_o_acc + i_split * kargs.split_stride_o_acc;

        // Q/K/V DRAM and DRAM window
        const auto q_dram = [&] {
            const auto q_dram_naive = [&] {
                if constexpr(kMergeNumHeadGroupsSeqLenQ)
                {
                    // reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
                    // hdim_q)
                    const auto view = make_naive_tensor_view<address_space_enum::global>(
                        q_ptr,
                        make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
                        make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1),
                        number<FmhaPipeline::kAlignmentQ>{},
                        number<1>{});

                    return transform_tensor_view(
                        view,
                        make_tuple(
                            make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
                            make_pass_through_transform(kargs.hdim_q)),
                        make_tuple(sequence<0, 1>{}, sequence<2>{}),
                        make_tuple(sequence<0>{}, sequence<1>{}));
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        q_ptr,
                        make_tuple(kargs.seqlen_q, kargs.hdim_q),
                        make_tuple(kargs.stride_q, 1),
                        number<FmhaPipeline::kAlignmentQ>{},
                        number<1>{});
                }
            }();

            if constexpr(FmhaPipeline::kQLoadOnce)
            {
                return pad_tensor_view(
                    q_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
            else
            {
                return pad_tensor_view(
                    q_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
        }();

        const auto make_k_dram = [&](const KDataType* data, index_t height) {
            const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                data, // will update this pointer if using paged-kvcache
                make_tuple(height, kargs.hdim_q),
                make_tuple(kargs.stride_k, 1),
                number<FmhaPipeline::kAlignmentK>{},
                number<1>{});

            return pad_tensor_view(
                k_dram_naive,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
                sequence<kPadSeqLenK, kPadHeadDimQ>{});
        };
        const auto k_dram = [&]() {
            if constexpr(kIsPagedKV)
            {
                return make_k_dram(nullptr, kargs.page_block_size);
            }
            else
            {
                return make_k_dram(k_ptr, kargs.seqlen_k);
            }
        }();

        const auto make_v_dram = [&](const VDataType* data, index_t length) {
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    data, // will update this pointer if using paged-kvcache
                    make_tuple(length, kargs.hdim_v),
                    make_tuple(kargs.stride_v, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});

                // transpose the view to (hdim_v, length) so both V layouts present the
                // same shape to the pipeline
                const auto v_dram_transposed =
                    transform_tensor_view(v_dram_naive,
                                          make_tuple(make_pass_through_transform(kargs.hdim_v),
                                                     make_pass_through_transform(length)),
                                          make_tuple(sequence<1>{}, sequence<0>{}),
                                          make_tuple(sequence<0>{}, sequence<1>{}));

                return pad_tensor_view(
                    v_dram_transposed,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK>{});
            }
            else
            {
                const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    data, // will update this pointer if using paged-kvcache
                    make_tuple(kargs.hdim_v, length),
                    make_tuple(kargs.stride_v, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});

                return pad_tensor_view(
                    v_dram_naive,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK>{});
            }
        };
        const auto v_dram = [&]() {
            if constexpr(kIsPagedKV)
            {
                return make_v_dram(nullptr, kargs.page_block_size);
            }
            else
            {
                return make_v_dram(v_ptr, kargs.seqlen_k);
            }
        }();

        auto k_page_block_navigator = [&, i_batch_ = i_batch]() {
            if constexpr(kIsPagedKV)
            {
                const auto* block_indices =
                    reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
                    i_batch_ * kargs.batch_stride_block_table;
                const index_t num_blocks =
                    integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);

                const long_index_t fixed_offset =
                    static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k;

                return make_page_block_navigator<const KDataType, 0>(
                    kargs.k_ptr,
                    kargs.batch_stride_k, // kcache page-block stride/size
                    fixed_offset,
                    block_indices,
                    num_blocks,
                    kargs.page_block_size,
                    k_dram,
                    make_k_dram(nullptr,
                                (kv_l2p_offset + kargs.seqlen_k) -
                                    (num_blocks - 1) * kargs.page_block_size));
            }
            else
            {
                return make_page_block_navigator(k_dram);
            }
        }();

        auto v_page_block_navigator = [&, i_batch_ = i_batch]() {
            if constexpr(kIsPagedKV)
            {
                const auto* block_indices =
                    reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
                    i_batch_ * kargs.batch_stride_block_table;
                const index_t num_blocks =
                    integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);

                const long_index_t fixed_offset =
                    static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v;

                return make_page_block_navigator<const VDataType, 1>(
                    kargs.v_ptr,
                    kargs.batch_stride_v, // vcache page-block stride/size
                    fixed_offset,
                    block_indices,
                    num_blocks,
                    kargs.page_block_size,
                    v_dram,
                    make_v_dram(nullptr,
                                (kv_l2p_offset + kargs.seqlen_k) -
                                    (num_blocks - 1) * kargs.page_block_size));
            }
            else
            {
                return make_page_block_navigator(v_dram);
            }
        }();
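
        // The navigators resolve a logical key position 'pos' to physical storage
        // roughly as:
        //   block  = block_indices[(kv_l2p_offset + pos) / page_block_size]
        //   offset = (kv_l2p_offset + pos) % page_block_size
        // e.g. with page_block_size = 16 and kv_l2p_offset = 0, pos = 37 lands in
        // page-block block_indices[2] at row 5. The last page-block may be only
        // partially filled, hence the shorter tail view passed above.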

        auto q_dram_window = make_tile_window(
            q_dram,
            [&]() {
                if constexpr(FmhaPipeline::kQLoadOnce)
                    return make_tuple(number<FmhaPipeline::kM0>{},
                                      number<FmhaPipeline::kSubQKHeaddim>{});
                else
                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
            }(),
            {i_m0, 0});

        auto k_dram_window_lengths =
            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
        auto v_dram_window_lengths =
            make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{});

        /// FIXME: Before C++20, capturing structured binding variables is not supported. Remove
        /// the following copy capture of 'i_nhead' once C++20 is available
        const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto bias_dram_window_lengths =
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                const BiasDataType* bias_ptr =
                    reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
                    batch_offset_bias;

                const auto bias_dram = [&]() {
                    const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                        bias_ptr,
                        make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                        make_tuple(kargs.stride_bias, 1),
                        number<FmhaPipeline::kAlignmentBias>{},
                        number<1>{});

                    return pad_tensor_view(
                        bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
                }();

                return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
            }
            else
            {
                return make_null_tile_window(bias_dram_window_lengths);
            }
        }();

        // lse acc
        auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() {
            constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
            LSEDataType* lse_acc_ptr = reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
                                       static_cast<long_index_t>(i_nhead_) *
                                           (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
                                           kargs.nhead_stride_lse_acc +
                                       batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;

            const auto lse_acc_dram = [&] {
                const auto lse_acc_dram_naive = [&] {
                    if constexpr(kMergeNumHeadGroupsSeqLenQ)
                    {
                        // reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q)
                        const auto view = make_naive_tensor_view<address_space_enum::global>(
                            lse_acc_ptr,
                            make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q),
                            make_tuple(kargs.nhead_stride_lse_acc, 1),
                            number<1>{},
                            number<1>{});

                        return transform_tensor_view(view,
                                                     make_tuple(make_merge_transform(make_tuple(
                                                         kargs.nhead_ratio_qk, kargs.seqlen_q))),
                                                     make_tuple(sequence<0, 1>{}),
                                                     make_tuple(sequence<0>{}));
                    }
                    else
                    {
                        return make_naive_tensor_view<address_space_enum::global>(
                            lse_acc_ptr,
                            make_tuple(kargs.seqlen_q),
                            make_tuple(1),
                            number<1>{},
                            number<1>{});
                    }
                }();
                return pad_tensor_view(
                    lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{});
            }();

            return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0});
        }();

        FmhaMask mask = [&]() {
            if constexpr(kHasMask)
                return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
                    kargs.window_size_left,
                    kargs.window_size_right,
                    kargs.seqlen_q,
                    kargs.seqlen_k,
                    kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
            else
                return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
        }();

        // workaround: structured bindings cannot be captured by lambdas before C++20,
        // so copy i_batch / i_nhead explicitly
        auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                // data loading, shared by entire wg
                // TODO: how to use s_read?
                SaccDataType slope =
                    *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
                      i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                slope *= ck_tile::log2e_v<>;
#endif
                if constexpr(kHasMask)
                {
                    return make_alibi_from_lr_mask<SaccDataType, true, 32>(slope,
                                                                           kargs.window_size_left,
                                                                           kargs.window_size_right,
                                                                           kargs.seqlen_q,
                                                                           kargs.seqlen_k,
                                                                           kargs.mask_type);
                }
                else
                {
                    return Alibi<SaccDataType, true, 32>{
                        slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
                }
            }
            else
            {
                return EmptyPositionEncoding<SaccDataType>{};
            }
        }();

        auto o_acc_tile = [&, i_split_ = i_split]() {
            if constexpr(kDoFp8StaticQuant)
            {
                return FmhaPipeline{}(q_dram_window,
                                      identity{}, // q_element_func
                                      k_dram_window_lengths,
                                      k_page_block_navigator,
                                      identity{}, // k_element_func
                                      v_dram_window_lengths,
                                      v_page_block_navigator,
                                      identity{}, // v_element_func
                                      bias_dram_window,
                                      identity{}, // bias_element_func
                                      lse_acc_dram_window,
                                      identity{}, // lse_element_func
                                      identity{}, // s_acc_element_func
                                      scales{kargs.scale_p}, // p_compute_element_func
                                      identity{}, // o_acc_element_func
                                      kargs.num_splits,
                                      i_split_,
                                      mask,
                                      position_encoding,
                                      kargs.scale_s,
                                      kv_l2p_offset,
                                      smem_ptr);
            }
            else
            {
                return FmhaPipeline{}(q_dram_window,
                                      k_dram_window_lengths,
                                      k_page_block_navigator,
                                      v_dram_window_lengths,
                                      v_page_block_navigator,
                                      bias_dram_window,
                                      lse_acc_dram_window,
                                      kargs.num_splits,
                                      i_split_,
                                      mask,
                                      position_encoding,
                                      kargs.scale_s,
                                      kv_l2p_offset,
                                      smem_ptr);
            }
        }();

        // Oacc DRAM and Oacc DRAM window
        auto o_acc_dram = [&] {
            const auto o_acc_dram_naive = [&] {
                if constexpr(kMergeNumHeadGroupsSeqLenQ)
                {
                    // reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
                    // hdim_v)
                    const auto view = make_naive_tensor_view<address_space_enum::global>(
                        o_acc_ptr,
                        make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),
                        make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1),
                        number<FmhaPipeline::kAlignmentOacc>{},
                        number<1>{});

                    return transform_tensor_view(
                        view,
                        make_tuple(
                            make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
                            make_pass_through_transform(kargs.hdim_v)),
                        make_tuple(sequence<0, 1>{}, sequence<2>{}),
                        make_tuple(sequence<0>{}, sequence<1>{}));
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        o_acc_ptr,
                        make_tuple(kargs.seqlen_q, kargs.hdim_v),
                        make_tuple(kargs.stride_o_acc, 1),
                        number<FmhaPipeline::kAlignmentOacc>{},
                        number<1>{});
                }
            }();

            return pad_tensor_view(
                o_acc_dram_naive,
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                sequence<kPadSeqLenQ, kPadHeadDimV>{});
        }();

        auto o_acc_dram_window =
            make_tile_window(o_acc_dram,
                             make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                             {i_m0, i_n1});

        EpiloguePipeline{}(o_acc_dram_window, o_acc_tile);
    }
};
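
// A minimal host-side usage sketch (illustrative only; 'FmhaKernel' stands for a
// concrete instantiation of FmhaFwdSplitKVKernel, and the argument list is
// abbreviated -- see MakeKargs() above for the full set):
//
//   using FmhaKernel = FmhaFwdSplitKVKernel</* FmhaPipeline */, /* EpiloguePipeline */>;
//
//   auto kargs  = FmhaKernel::MakeKargs(q, k, v, bias, lse_acc, o_acc, batch, ...);
//   dim3 grids  = FmhaKernel::GridSize(batch, nhead_q, nhead_kv, max_seqlen_q,
//                                      hdim_v, num_splits);
//   dim3 blocks = FmhaKernel::BlockSize();
//
//   // launch with any HIP-style mechanism (e.g. ck_tile::launch_kernel with
//   // ck_tile::make_kernel), then run the split-combine kernel over the
//   // lse_acc / o_acc workspaces to produce the final O and LSE.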

} // namespace ck_tile