/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File
fmha_batch_prefill_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
12 
13 #include <string>
14 #include <type_traits>
15 #include <utility>
16 #include <variant>
17 
18 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
19 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
20 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
21 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
22 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
23 
24 namespace ck_tile {
25 
26 template <typename FmhaPipeline_, typename EpiloguePipeline_>
28 {
31  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
32 
33  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
34  static_assert(kBlockPerCu > 0);
35  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
36 
47 
49 
50  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
51  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
52  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
53  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
54  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
55  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
56  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
57  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
58  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
59  static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
60  static constexpr auto kKVMemoryLayout = FmhaPipeline::Problem::kKVMemoryLayout;
61  static constexpr auto kKVLookupTable = FmhaPipeline::Problem::kKVLookupTable;
62  static constexpr index_t kPageBlockSize = FmhaPipeline::kPageBlockSize;
63  static constexpr index_t kVectorSize = FmhaPipeline::kVectorSize;
66  static constexpr bool kHasMask = FmhaMask::IsMasking;
67 
68  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
69  template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
70  // arg
72  {
73  };
74 
 75  // kargs use aggregate initialization, so no constructor will be provided
 76  // use inheritance to minimize karg size
 77  // users need to use the MakeKargs() function to create kargs.
79  {
83  };
84 
86  {
90  };
91 
97 
99  {
100  const void* q_ptr;
101  const void* k_ptr;
102  const void* v_ptr;
103  void* o_ptr;
104 
109 
111  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
112  // if this param is larger than 1, indicate MQA/GQA case
114 
118 
119  float scale_s;
120 
125 
130  };
131 
133  {
135 
136  void init_logits_soft_cap(float logits_soft_cap_)
137  {
138  if(0 < logits_soft_cap_)
139  {
140  logits_soft_cap = logits_soft_cap_;
142  }
143  else
144  {
145  logits_soft_cap = 0.f;
146  logits_soft_cap_rcp = 0.f;
147  }
148  }
149 
152  };
153 
155  {
156  const void* bias_ptr = nullptr;
159  };
160 
162  {
164  };
165 
167  {
168  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
169  const void* alibi_slope_ptr;
170  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
171  };
172 
174  {
175  // ck_tile::index_t window_size_left, window_size_right;
178  };
179 
181  {
182  void* lse_ptr = nullptr;
185  };
186 
188  {
189  const void* q_descale_ptr = nullptr;
190  const void* k_descale_ptr = nullptr;
191  const void* v_descale_ptr = nullptr;
192  };
193 
195  {
196  template <typename T>
198  {
199  T val;
200  const T* ptr;
201  };
202 
206  };
207 
209  {
210  void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
211  {
212  float p_undrop = 1.0 - p_drop;
215  rp_undrop = 1.0 / p_undrop;
216 
217  this->drop_seed.val = seed;
218  this->drop_offset.val = offset;
219  this->is_drop_seed_offset_from_host = true;
220  }
221 
222  void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
223  {
224  float p_undrop = 1.0 - p_drop;
227  rp_undrop = 1.0 / p_undrop;
228 
229  this->drop_seed.ptr = seed_ptr;
230  this->drop_offset.ptr = offset_ptr;
231  this->is_drop_seed_offset_from_host = false;
232  }
233 
234  float rp_undrop = 1;
236  bool is_store_randval = false;
237  void* rand_val_ptr = nullptr;
238 
241  };
242 
244  {
246  };
247 
250  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
251  FmhaFwdBatchModeBiasKargs,
252  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
253  FmhaFwdAlibiKargs,
254  FmhaFwdEmptyKargs<0>>>,
255  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
256  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
257  std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
258  FmhaFwdCommonQScaleKargs,
259  FmhaFwdEmptyKargs<3>>,
260  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
261  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
262  {
267  };
268 
271  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
272  FmhaFwdCommonBiasKargs,
273  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
274  FmhaFwdAlibiKargs,
275  FmhaFwdEmptyKargs<0>>>,
276  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
277  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
278  std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
279  FmhaFwdCommonQScaleKargs,
280  FmhaFwdEmptyKargs<3>>,
281  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
282  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
283  {
287  };
288 
289  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
290 
292  {
296  };
297 
298  template <bool Cond = !kIsGroupMode>
299  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
300  MakeKargs(const void* q_ptr,
301  const void* k_ptr,
302  const void* v_ptr,
303  const void* bias_ptr,
304  const void* q_descale_ptr,
305  const void* k_descale_ptr,
306  const void* v_descale_ptr,
307  void* rand_val_ptr,
308  void* lse_ptr,
309  void* o_ptr,
310  ck_tile::index_t seqlen_q,
311  ck_tile::index_t hdim_q,
312  ck_tile::index_t hdim_v,
313  ck_tile::index_t num_head_q,
314  ck_tile::index_t nhead_ratio_qk,
315  int32_t num_total_pages,
316  ck_tile::index_t page_block_size,
317  const PageBlockTableKargs& page_table,
318  float scale_s,
319  [[maybe_unused]] float scale_p,
320  [[maybe_unused]] float scale_o,
321  float logits_soft_cap,
322  ck_tile::index_t stride_q,
323  ck_tile::index_t stride_k,
324  ck_tile::index_t stride_v,
325  ck_tile::index_t stride_bias,
326  ck_tile::index_t stride_randval,
327  ck_tile::index_t stride_o,
328  ck_tile::index_t nhead_stride_q,
329  ck_tile::index_t nhead_stride_k,
330  ck_tile::index_t nhead_stride_v,
331  ck_tile::index_t nhead_stride_bias,
332  ck_tile::index_t nhead_stride_randval,
333  ck_tile::index_t nhead_stride_lse,
334  ck_tile::index_t nhead_stride_o,
335  ck_tile::index_t batch_stride_q,
336  ck_tile::index_t batch_stride_k,
337  ck_tile::index_t batch_stride_v,
338  ck_tile::index_t batch_stride_bias,
339  ck_tile::index_t batch_stride_randval,
340  ck_tile::index_t batch_stride_lse,
341  ck_tile::index_t batch_stride_o,
342  ck_tile::index_t window_size_left,
343  ck_tile::index_t window_size_right,
344  ck_tile::index_t sink_size,
345  ck_tile::index_t mask_type,
346  float p_drop,
347  bool s_randval,
348  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
349  drop_seed_offset)
350  {
351  Kargs kargs{{q_ptr,
352  k_ptr,
353  v_ptr,
354  o_ptr,
355  seqlen_q,
356  -1,
357  hdim_q,
358  hdim_v,
359  num_head_q,
360  nhead_ratio_qk,
361  num_total_pages,
362  page_block_size,
363  page_table,
364 #if CK_TILE_FMHA_FWD_FAST_EXP2
365  static_cast<float>(scale_s * ck_tile::log2e_v<>),
366 #else
367  scale_s,
368 #endif
369  stride_q,
370  stride_k,
371  stride_v,
372  stride_o,
373  nhead_stride_q,
374  nhead_stride_k,
375  nhead_stride_v,
376  nhead_stride_o}, // args for common karg
377  {}, // placeholder for bias
378  {}, // placeholder for mask
379  {}, // placeholder for lse
380  {}, // placeholder for qscale
381  {}, // placeholder for dropout
382  {}, // placeholder for logits_soft_cap
383  batch_stride_q,
384  batch_stride_k,
385  batch_stride_v,
386  batch_stride_o};
387 
389  {
390  kargs.bias_ptr = bias_ptr;
391  kargs.stride_bias = stride_bias;
392  kargs.nhead_stride_bias = nhead_stride_bias;
393  kargs.batch_stride_bias = batch_stride_bias;
394  }
395  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
396  {
397  kargs.alibi_slope_ptr = bias_ptr;
398  kargs.alibi_slope_stride = stride_bias;
399  }
400  if constexpr(kHasMask)
401  {
402  kargs.window_size_left = window_size_left;
403  kargs.window_size_right = window_size_right;
404  kargs.sink_size = sink_size;
405  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
406  }
407  if constexpr(kStoreLSE)
408  {
409  kargs.lse_ptr = lse_ptr;
410  kargs.nhead_stride_lse = nhead_stride_lse;
411  kargs.batch_stride_lse = batch_stride_lse;
412  }
414  {
415  kargs.q_descale_ptr = q_descale_ptr;
416  kargs.k_descale_ptr = k_descale_ptr;
417  kargs.v_descale_ptr = v_descale_ptr;
418  }
419  if constexpr(kHasDropout)
420  {
421  if(drop_seed_offset.index() == 0) // seed & offset come from host
422  {
423  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
424  kargs.init_dropout(p_drop, seed, offset);
425  }
426  else // seed & offset come from device
427  {
428  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
429  kargs.init_dropout(p_drop,
430  reinterpret_cast<const uint64_t*>(seed_ptr),
431  reinterpret_cast<const uint64_t*>(offset_ptr));
432  }
433 
434  kargs.rand_val_ptr = rand_val_ptr;
435  kargs.stride_randval = stride_randval;
436  kargs.nhead_stride_randval = nhead_stride_randval;
437  kargs.batch_stride_randval = batch_stride_randval;
438  kargs.is_store_randval = s_randval;
439  }
440  if constexpr(kHasLogitsSoftCap)
441  {
442  kargs.init_logits_soft_cap(logits_soft_cap);
443  }
444 
445  return kargs;
446  }
447 
448  template <bool Cond = kIsGroupMode>
449  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
450  MakeKargs(const void* q_ptr,
451  const void* k_ptr,
452  const void* v_ptr,
453  const void* bias_ptr,
454  const void* q_descale_ptr,
455  const void* k_descale_ptr,
456  const void* v_descale_ptr,
457  void* rand_val_ptr,
458  void* lse_ptr,
459  void* o_ptr,
460  const void* seqstart_q_ptr,
461  ck_tile::index_t hdim_q,
462  ck_tile::index_t hdim_v,
463  ck_tile::index_t num_head_q,
464  ck_tile::index_t nhead_ratio_qk,
465  int32_t num_total_pages,
466  ck_tile::index_t page_block_size,
467  const PageBlockTableKargs& page_table,
468  float scale_s,
469  [[maybe_unused]] float scale_p,
470  [[maybe_unused]] float scale_o,
471  float logits_soft_cap,
472  ck_tile::index_t stride_q,
473  ck_tile::index_t stride_k,
474  ck_tile::index_t stride_v,
475  ck_tile::index_t stride_bias,
476  ck_tile::index_t stride_randval,
477  ck_tile::index_t stride_o,
478  ck_tile::index_t nhead_stride_q,
479  ck_tile::index_t nhead_stride_k,
480  ck_tile::index_t nhead_stride_v,
481  ck_tile::index_t nhead_stride_bias,
482  ck_tile::index_t nhead_stride_randval,
483  ck_tile::index_t nhead_stride_lse,
484  ck_tile::index_t nhead_stride_o,
485  ck_tile::index_t batch_stride_k,
486  ck_tile::index_t batch_stride_v,
487  ck_tile::index_t window_size_left,
488  ck_tile::index_t window_size_right,
489  ck_tile::index_t sink_size,
490  ck_tile::index_t mask_type,
491  float p_drop,
492  bool s_randval,
493  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
494  drop_seed_offset)
495  {
496  Kargs kargs{{q_ptr,
497  k_ptr,
498  v_ptr,
499  o_ptr,
500  -1, // seqlen will be updated by another pointer
501  -1, //
502  hdim_q,
503  hdim_v,
504  num_head_q,
505  nhead_ratio_qk,
506  num_total_pages,
507  page_block_size,
508  page_table,
509 #if CK_TILE_FMHA_FWD_FAST_EXP2
510  static_cast<float>(scale_s * ck_tile::log2e_v<>),
511 #else
512  scale_s,
513 #endif
514  stride_q,
515  stride_k,
516  stride_v,
517  stride_o,
518  nhead_stride_q,
519  nhead_stride_k,
520  nhead_stride_v,
521  nhead_stride_o}, // args for common karg
522  {}, // placeholder for bias
523  {}, // placeholder for mask
524  {}, // placeholder for lse
525  {}, // placeholder for qscale
526  {}, // placeholder for dropout
527  {}, // placeholder for logits_soft_cap
528  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
529  batch_stride_k,
530  batch_stride_v};
531 
533  {
534  kargs.bias_ptr = bias_ptr;
535  kargs.stride_bias = stride_bias;
536  kargs.nhead_stride_bias = nhead_stride_bias;
537  }
538  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
539  {
540  kargs.alibi_slope_ptr = bias_ptr;
541  kargs.alibi_slope_stride = stride_bias;
542  }
543  if constexpr(kHasMask)
544  {
545  kargs.window_size_left = window_size_left;
546  kargs.window_size_right = window_size_right;
547  kargs.sink_size = sink_size;
548  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
549  }
550  if constexpr(kStoreLSE)
551  {
552  kargs.lse_ptr = lse_ptr;
553  kargs.nhead_stride_lse = nhead_stride_lse;
554  }
556  {
557  kargs.q_descale_ptr = q_descale_ptr;
558  kargs.k_descale_ptr = k_descale_ptr;
559  kargs.v_descale_ptr = v_descale_ptr;
560  }
561  if constexpr(kHasDropout)
562  {
563  if(drop_seed_offset.index() == 0) // seed & offset come from host
564  {
565  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
566  kargs.init_dropout(p_drop, seed, offset);
567  }
568  else // seed & offset come from device
569  {
570  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
571  kargs.init_dropout(p_drop,
572  reinterpret_cast<const uint64_t*>(seed_ptr),
573  reinterpret_cast<const uint64_t*>(offset_ptr));
574  }
575 
576  kargs.rand_val_ptr = rand_val_ptr;
577  kargs.stride_randval = stride_randval;
578  kargs.nhead_stride_randval = nhead_stride_randval;
579  kargs.is_store_randval = s_randval;
580  }
581  if constexpr(kHasLogitsSoftCap)
582  {
583  kargs.init_logits_soft_cap(logits_soft_cap);
584  }
585 
586  return kargs;
587  }
588 
589  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
590  ck_tile::index_t nhead_,
591  ck_tile::index_t seqlen_q_,
592  ck_tile::index_t hdim_v_)
593  {
594  if constexpr(kIsGroupMode)
595  {
596  // TODO: this may need tuning
597  return dim3(nhead_,
598  batch_size_,
599  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
600  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
601  }
602  else
603  {
604  // TODO: this may need tuning
605  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
606  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
607  nhead_,
608  batch_size_);
609  }
610  }
611 
612  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
613  {
614  if constexpr(kIsGroupMode)
615  {
616  // const index_t num_tile_m0 = seqlen_q / kM0;
617  const index_t num_tile_n1 =
618  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
619 
620  const index_t i_block = blockIdx.z;
621  const index_t i_nhead = blockIdx.x;
622  const index_t i_batch = blockIdx.y;
623 
624  const auto f = [](index_t dividend, index_t divisor) {
625  index_t quotient = dividend / divisor;
626  index_t modulus = dividend - quotient * divisor;
627  return ck_tile::make_tuple(quotient, modulus);
628  };
629 
630  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
631  if constexpr(kHasMask)
632  {
633  // assume that num_tile_n1 is always 1
634  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
635  }
636  else
637  {
638  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
639  }
640  }
641  else
642  {
643  // const index_t num_tile_m0 = seqlen_q / kM0;
644  const index_t num_tile_n1 =
645  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
646 
647  const index_t i_block = blockIdx.x;
648  const index_t i_nhead = blockIdx.y;
649  const index_t i_batch = blockIdx.z;
650 
651  const auto f = [](index_t dividend, index_t divisor) {
652  index_t quotient = dividend / divisor;
653  index_t modulus = dividend - quotient * divisor;
654  return ck_tile::make_tuple(quotient, modulus);
655  };
656 
657  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
658 
659  if constexpr(kHasMask)
660  {
661  // assume that num_tile_n1 is always 1
662  return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
663  }
664  else
665  {
666  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
667  }
668  }
669  }
670 
671  CK_TILE_HOST static dim3 BlockSize()
672  {
673  if(is_wave32())
674  {
675  return dim3(kBlockSize / 2);
676  }
677  else
678  {
679  return dim3(kBlockSize);
680  }
681  }
682 
684  {
685  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
686  }
687 
688  CK_TILE_DEVICE void operator()(Kargs kargs) const
689  {
690  // allocate LDS
691  __shared__ char smem_ptr[GetSmemSize()];
692 
693  // divide problem
694  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
695 
696  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
697  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
698 
699  long_index_t batch_offset_q = 0;
700  long_index_t batch_offset_bias = 0;
701  long_index_t batch_offset_randval = 0;
702  long_index_t batch_offset_lse = 0;
703  long_index_t batch_offset_o = 0;
704 
705  const index_t seqlen_k = [&]() {
706  if constexpr(kKVLookupTable ==
708  {
709  const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
710  const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1];
711  const int32_t num_page_blocks = page_end - page_start;
712  const int32_t last_page_len = [&]() {
713  if constexpr(kPageBlockSize == 1)
714  return static_cast<int32_t>(kPageBlockSize);
715  else
716  return kargs.page_table.kv_last_page_lens[i_batch];
717  }();
718  return num_page_blocks > 0
719  ? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
720  last_page_len)
721  : 0;
722  }
723  else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
724  {
725  if(kargs.page_table.seqlen_k_ptr != nullptr)
726  return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
727  else
728  return kargs.seqlen_k;
729  }
730  }();
731  const int32_t* page_idx = [&]() {
732  if constexpr(kKVLookupTable ==
734  {
735  return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
736  }
737  else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
738  {
739  return kargs.page_table.block_table_ptr +
740  static_cast<long_index_t>(i_batch) *
741  kargs.page_table.batch_stride_block_table;
742  }
743  }();
744 
745  if constexpr(kIsGroupMode)
746  {
747  // get starting offset for each batch
748  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
749 
750  batch_offset_q = query_start * kargs.stride_q;
751 
753  {
754  batch_offset_bias = query_start * kargs.stride_bias;
755  }
756  if constexpr(kStoreLSE)
757  {
758  batch_offset_lse = query_start;
759  }
760  if constexpr(kHasDropout)
761  {
762  batch_offset_randval = query_start * kargs.stride_randval;
763  }
764  batch_offset_o = query_start * kargs.stride_o;
765 
766  // get real # queries & # keys under group mode
767  kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - query_start;
768 
769  // # of required blocks is different in each groups, terminate unnecessary blocks
770  // earlier
771  if(kargs.seqlen_q <= i_m0)
772  {
773  return;
774  }
775 
776  kargs.seqlen_k = seqlen_k;
777  }
778  else
779  {
780  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
781 
783  {
784  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
785  }
786  if constexpr(kStoreLSE)
787  {
788  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
789  }
790  if constexpr(kHasDropout)
791  {
792  batch_offset_randval =
793  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
794  }
795  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
796 
797  kargs.seqlen_k = seqlen_k;
798  }
799 
800  // for simplicity, batch stride we just modify the pointer
801  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
802  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
803  batch_offset_q;
804  const KDataType* k_ptr =
805  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
806  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
807  const VDataType* v_ptr =
808  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
809  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v;
810  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
811  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
812  batch_offset_o;
813 
814  // Q/K/V DRAM and DRAM window
815  const auto q_dram = [&]() {
816  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
817  q_ptr,
818  make_tuple(kargs.seqlen_q, kargs.hdim_q),
819  make_tuple(kargs.stride_q, 1),
821  number<1>{});
822  if constexpr(FmhaPipeline::kQLoadOnce)
823  {
824  return pad_tensor_view(
825  q_dram_naive,
828  }
829  else
830  {
831  return pad_tensor_view(
832  q_dram_naive,
835  }
836  }();
837  const auto k_dram = [&]() {
838  if constexpr(kKVMemoryLayout ==
840  {
841  // Vectorized K Layout: [NumPages, D/kVectorSize, S, kVectorSize]
842  // Logical View for Pipeline: (TotalSeqK, D)
843 
844  // Define the naive physical view with 4D shape: (NumPages, HeadDim/kVectorSize,
845  // PageBlockSize, kVectorSize)
846  // Strides: (BatchStride, PageBlockSize*kVectorSize, kVectorSize, 1)
847  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
848  k_ptr,
849  make_tuple(kargs.num_total_pages,
850  kargs.hdim_q / kVectorSize,
851  kargs.page_block_size,
852  kVectorSize),
853  make_tuple(
854  kargs.batch_stride_k, kargs.page_block_size * kVectorSize, kVectorSize, 1),
856  number<1>{});
857 
858  // Merge to (TotalSeqK, D) in a single transform:
859  // physical (Page, D/vec, S, vec) -> logical (TotalSeqK, D)
860  auto k_dram_2d = transform_tensor_view(
861  k_dram_naive,
862  make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages,
863  kargs.page_block_size)), // TotalSeqK
865  make_tuple(static_cast<int32_t>(kargs.hdim_q / kVectorSize),
866  static_cast<int32_t>(kVectorSize)))), // D
869 
870  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
871  return pad_tensor_view(
872  k_dram_2d,
875  }
876  else
877  {
878  // Linear K Layout: [NumPages, PageSize, NumHeads, HeadDim]
879  // Logical View for Pipeline: (TotalSeqK, D)
880  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
881  k_ptr,
882  make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_q),
883  make_tuple(kargs.batch_stride_k, kargs.stride_k, 1),
885  number<1>{});
886 
887  // Merge to (TotalSeqK, D) in a single transform:
888  // physical (Page, S, D) -> logical (TotalSeqK, D)
889  auto k_dram_2d = transform_tensor_view(
890  k_dram_naive,
892  make_tuple(kargs.num_total_pages, kargs.page_block_size)),
893  make_pass_through_transform(kargs.hdim_q)),
896 
897  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
898  return pad_tensor_view(
899  k_dram_2d,
902  }
903  }();
904  const auto v_dram = [&]() {
905  if constexpr(kKVMemoryLayout ==
907  {
908  // Vectorized V Layout: [NumPages, S/kVectorSize, D, kVectorSize]
909  // Logical View for Pipeline: (D, TotalSeqK) - Transposed for GEMM
910 
911  // Define the naive physical view with 4D shape: (NumPages,
912  // PageBlockSize/kVectorSize, HeadDim, kVectorSize)
913  // Strides: (BatchStride, HeadDim*kVectorSize, kVectorSize, 1)
914  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
915  v_ptr,
916  make_tuple(kargs.num_total_pages,
917  kargs.page_block_size / kVectorSize,
918  kargs.hdim_v,
919  kVectorSize),
920  make_tuple(kargs.batch_stride_v, kargs.hdim_v * kVectorSize, kVectorSize, 1),
922  number<1>{});
923 
924  // Merge to (D, TotalSeqK) in a single transform:
925  // physical (Page, S/vec, D, vec) -> logical (D, TotalSeqK)
926  auto v_dram_final = transform_tensor_view(
927  v_dram_naive,
928  make_tuple(make_pass_through_transform(kargs.hdim_v), // D
929  make_merge_transform(make_tuple(kargs.num_total_pages,
930  kargs.page_block_size / kVectorSize,
931  kVectorSize))), // TotalSeqK
934 
935  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
936  return pad_tensor_view(
937  v_dram_final,
940  }
941  else
942  {
943  // Linear V Layout: [NumPages, PageSize, NumHeads, HeadDim]
944  // Logical View for Pipeline: (D, TotalSeqK)
945  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
946  v_ptr,
947  make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_v),
948  make_tuple(kargs.batch_stride_v, kargs.stride_v, 1),
950  number<1>{});
951 
952  // Merge to (D, TotalSeqK) in a single transform:
953  // physical (Page, S, D) -> logical (D, TotalSeqK)
954  auto v_dram_final = transform_tensor_view(
955  v_dram_naive,
958  make_tuple(kargs.num_total_pages, kargs.page_block_size))),
961 
962  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
963  return pad_tensor_view(
964  v_dram_final,
967  }
968  }();
969  auto q_dram_window = make_tile_window(
970  q_dram,
971  [&]() {
972  if constexpr(FmhaPipeline::kQLoadOnce)
975  else
977  }(),
978  {i_m0, 0});
979 
980  auto k_dram_window = make_tile_window(
982 
983  auto v_dram_window =
984  make_tile_window(v_dram,
986  {i_n1, 0});
989  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
990  constexpr auto bias_dram_window_lengths =
993  {
994  const BiasDataType* bias_ptr =
995  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
996  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
997  batch_offset_bias;
998 
999  const auto bias_dram = [&]() {
1000  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1001  bias_ptr,
1002  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1003  make_tuple(kargs.stride_bias, 1),
1005  number<1>{});
1006 
1007  return pad_tensor_view(bias_dram_naive,
1008  bias_dram_window_lengths,
1010  }();
1011 
1012  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1013  }
1014  else
1015  {
1016  return make_null_tile_window(bias_dram_window_lengths);
1017  }
1018  }();
1019 
1020  // lse
1021  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1022  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1023  if constexpr(kStoreLSE)
1024  {
1025  LSEDataType* lse_ptr =
1026  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1027  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
1028 
1029  const auto lse_dram = [&]() {
1030  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1031  lse_ptr,
1032  make_tuple(kargs.seqlen_q),
1033  make_tuple(1),
1034  number<1>{},
1035  number<1>{});
1036 
1037  return pad_tensor_view(
1038  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1039  }();
1040 
1041  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1042  }
1043  else
1044  {
1045  return make_null_tile_window(lse_dram_window_lengths);
1046  }
1047  }();
1048 
1049  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1050  if constexpr(kHasDropout)
1051  {
1052  return BlockDropout{i_batch_,
1053  i_nhead_,
1054  kargs.num_head_q,
1055  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1056  : *kargs.drop_seed.ptr,
1057  kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
1058  : *kargs.drop_offset.ptr,
1059  kargs.rp_undrop,
1060  kargs.p_undrop_in_uint8_t,
1061  kargs.is_store_randval};
1062  }
1063  else
1064  {
1065  return NullBlockDropout{};
1066  };
1067  }();
1068 
1069  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1070  constexpr auto randval_dram_window_lengths =
1072  if constexpr(kHasDropout)
1073  {
1074  RandValOutputDataType* rand_val_ptr =
1075  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1076  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1077  batch_offset_randval;
1078 
1079  const auto randval_dram = [&]() {
1080  const auto randval_dram_naive =
1081  make_naive_tensor_view<address_space_enum::global>(
1082  rand_val_ptr,
1083  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1084  make_tuple(kargs.stride_randval, 1),
1086  number<1>{});
1087 
1088  return pad_tensor_view(randval_dram_naive,
1089  randval_dram_window_lengths,
1091  }();
1092 
1093  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1094  }
1095  else
1096  {
1097  return make_null_tile_window(randval_dram_window_lengths);
1098  }
1099  }();
1100 
1101  FmhaMask mask = [&]() {
1102  if constexpr(kHasMask)
1103  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1104  kargs.window_size_left,
1105  kargs.window_size_right,
1106  kargs.sink_size,
1107  kargs.seqlen_q,
1108  kargs.seqlen_k,
1110  else
1111  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1112  }();
1113 
1114  // WA i_batch capture structure binding before c++20
1115  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1116  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1117  {
1118  // data loading, shared by entire wg
1119  // TODO: how to use s_read?
1120  SaccDataType slope =
1121  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1122  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1123 #if CK_TILE_FMHA_FWD_FAST_EXP2
1124  slope *= ck_tile::log2e_v<>;
1125 #endif
1126  if constexpr(kHasMask)
1127  {
1128  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1129  kargs.window_size_left,
1130  kargs.window_size_right,
1131  kargs.seqlen_q,
1132  kargs.seqlen_k,
1133  kargs.mask_type);
1134  }
1135  else
1136  {
1138  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1139  }
1140  }
1141  else
1142  {
1144  }
1145  }();
1146 
1147  AttentionVariant variant;
1148  const auto variant_params = [&] {
1149  const float scale_s = [&] {
1151  {
1152  float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
1153  float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
1154 
1155  return kargs.scale_s * q_descale * k_descale;
1156  }
1157  else
1158  {
1159  return kargs.scale_s;
1160  }
1161  }();
1162 
1163  if constexpr(kHasLogitsSoftCap)
1164  {
1166  mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1167  }
1168  else
1169  {
1170  return ck_tile::StandardAttentionParams<FmhaMask>{mask, scale_s};
1171  }
1172  }();
1173 
1174  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1175 
1176  const index_t stride_k_for_pipeline =
1178  ? kVectorSize
1179  : kargs.stride_k;
1180  const index_t stride_v_for_pipeline =
1182  ? kargs.hdim_v
1183  : kargs.stride_v;
1184 
1185  auto o_acc_tile = [&] {
1187  {
1188  // TODO - move global load of descale to pipeline
1189  float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
1190 
1191  float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
1192  float scale_o = v_descale / scale_p;
1193 
1194  auto o_acc_element_func = [&]() {
1195  if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
1197  ck_tile::scales{scale_o});
1198  else
1199  return ck_tile::scales{scale_o};
1200  }();
1201 
1202  return FmhaPipeline{}(q_dram_window,
1203  identity{}, // q_element_func
1204  k_dram_window,
1205  identity{}, // k_element_func
1206  v_dram_window,
1207  identity{}, // v_element_func
1208  bias_dram_window,
1209  identity{}, // bias_element_func
1210  randval_dram_window,
1211  lse_dram_window,
1212  identity{}, // lse_element_func
1213  identity{}, // s_acc_element_func
1214  scales{scale_p}, // p_compute_element_func
1215  o_acc_element_func, // o_acc_element_func
1216  mask,
1217  position_encoding,
1218  variant_params.sm_scale,
1219  variant,
1220  variant_params,
1221  block_indices,
1222  smem_ptr,
1223  page_idx,
1224  stride_k_for_pipeline,
1225  stride_v_for_pipeline,
1226  kargs.batch_stride_k,
1227  kargs.batch_stride_v,
1228  dropout);
1229  }
1230  else
1231  {
1232  return FmhaPipeline{}(q_dram_window,
1233  k_dram_window,
1234  v_dram_window,
1235  bias_dram_window,
1236  randval_dram_window,
1237  lse_dram_window,
1238  mask,
1239  position_encoding,
1240  variant_params.sm_scale,
1241  variant,
1242  variant_params,
1243  block_indices,
1244  smem_ptr,
1245  page_idx,
1246  stride_k_for_pipeline,
1247  stride_v_for_pipeline,
1248  kargs.batch_stride_k,
1249  kargs.batch_stride_v,
1250  dropout);
1251  }
1252  }();
1253 
1254  // O DRAM and O DRAM window
1255  auto o_dram = [&]() {
1256  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1257  o_ptr,
1258  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1259  make_tuple(kargs.stride_o, 1),
1261  number<1>{});
1262 
1263  return pad_tensor_view(
1264  o_dram_naive,
1267  }();
1268 
1269  auto o_dram_window =
1270  make_tile_window(o_dram,
1272  {i_m0, i_n1});
1273 
1274  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1275  }
1276 };
1277 
1278 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
CK_TILE_HOST_DEVICE_EXTERN composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:526
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1690
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1633
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
unsigned char uint8_t
Definition: stdint.h:124
unsigned __int64 uint64_t
Definition: stdint.h:136
Definition: block_position_encoding.hpp:48
Definition: block_dropout.hpp:53
const float rp_undrop
Definition: block_dropout.hpp:371
Definition: block_position_encoding.hpp:137
Definition: fmha_batch_prefill_kernel.hpp:292
ck_tile::index_t kv_head_idx
Definition: fmha_batch_prefill_kernel.hpp:295
ck_tile::index_t qo_head_idx
Definition: fmha_batch_prefill_kernel.hpp:294
ck_tile::index_t batch_idx
Definition: fmha_batch_prefill_kernel.hpp:293
Definition: fmha_batch_prefill_kernel.hpp:167
ck_tile::index_t alibi_slope_stride
Definition: fmha_batch_prefill_kernel.hpp:170
const void * alibi_slope_ptr
Definition: fmha_batch_prefill_kernel.hpp:169
ck_tile::index_t batch_stride_bias
Definition: fmha_batch_prefill_kernel.hpp:163
ck_tile::index_t batch_stride_randval
Definition: fmha_batch_prefill_kernel.hpp:245
ck_tile::index_t batch_stride_o
Definition: fmha_batch_prefill_kernel.hpp:266
ck_tile::index_t batch_stride_v
Definition: fmha_batch_prefill_kernel.hpp:265
ck_tile::index_t batch_stride_q
Definition: fmha_batch_prefill_kernel.hpp:263
ck_tile::index_t batch_stride_k
Definition: fmha_batch_prefill_kernel.hpp:264
ck_tile::index_t nhead_stride_bias
Definition: fmha_batch_prefill_kernel.hpp:158
ck_tile::index_t stride_bias
Definition: fmha_batch_prefill_kernel.hpp:157
const void * bias_ptr
Definition: fmha_batch_prefill_kernel.hpp:156
ck_tile::index_t stride_randval
Definition: fmha_batch_prefill_kernel.hpp:239
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition: fmha_batch_prefill_kernel.hpp:222
ck_tile::index_t nhead_stride_randval
Definition: fmha_batch_prefill_kernel.hpp:240
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition: fmha_batch_prefill_kernel.hpp:210
void * rand_val_ptr
Definition: fmha_batch_prefill_kernel.hpp:237
float rp_undrop
Definition: fmha_batch_prefill_kernel.hpp:234
bool is_store_randval
Definition: fmha_batch_prefill_kernel.hpp:236
uint8_t p_undrop_in_uint8_t
Definition: fmha_batch_prefill_kernel.hpp:235
ck_tile::index_t page_block_size
Definition: fmha_batch_prefill_kernel.hpp:116
ck_tile::index_t stride_q
Definition: fmha_batch_prefill_kernel.hpp:121
ck_tile::index_t stride_v
Definition: fmha_batch_prefill_kernel.hpp:123
int32_t num_total_pages
Definition: fmha_batch_prefill_kernel.hpp:115
float scale_s
Definition: fmha_batch_prefill_kernel.hpp:119
PageBlockTableKargs page_table
Definition: fmha_batch_prefill_kernel.hpp:117
ck_tile::index_t seqlen_q
Definition: fmha_batch_prefill_kernel.hpp:105
ck_tile::index_t stride_k
Definition: fmha_batch_prefill_kernel.hpp:122
ck_tile::index_t nhead_stride_o
Definition: fmha_batch_prefill_kernel.hpp:129
ck_tile::index_t nhead_stride_k
Definition: fmha_batch_prefill_kernel.hpp:127
ck_tile::index_t nhead_ratio_qk
Definition: fmha_batch_prefill_kernel.hpp:113
ck_tile::index_t nhead_stride_v
Definition: fmha_batch_prefill_kernel.hpp:128
ck_tile::index_t nhead_stride_q
Definition: fmha_batch_prefill_kernel.hpp:126
const void * v_ptr
Definition: fmha_batch_prefill_kernel.hpp:102
void * o_ptr
Definition: fmha_batch_prefill_kernel.hpp:103
ck_tile::index_t seqlen_k
Definition: fmha_batch_prefill_kernel.hpp:106
ck_tile::index_t stride_o
Definition: fmha_batch_prefill_kernel.hpp:124
ck_tile::index_t hdim_v
Definition: fmha_batch_prefill_kernel.hpp:108
ck_tile::index_t num_head_q
Definition: fmha_batch_prefill_kernel.hpp:110
const void * k_ptr
Definition: fmha_batch_prefill_kernel.hpp:101
ck_tile::index_t hdim_q
Definition: fmha_batch_prefill_kernel.hpp:107
const void * q_ptr
Definition: fmha_batch_prefill_kernel.hpp:100
ck_tile::index_t batch_stride_lse
Definition: fmha_batch_prefill_kernel.hpp:184
ck_tile::index_t nhead_stride_lse
Definition: fmha_batch_prefill_kernel.hpp:183
void * lse_ptr
Definition: fmha_batch_prefill_kernel.hpp:182
const void * v_descale_ptr
Definition: fmha_batch_prefill_kernel.hpp:191
const void * k_descale_ptr
Definition: fmha_batch_prefill_kernel.hpp:190
const void * q_descale_ptr
Definition: fmha_batch_prefill_kernel.hpp:189
bool is_drop_seed_offset_from_host
Definition: fmha_batch_prefill_kernel.hpp:205
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_batch_prefill_kernel.hpp:203
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_batch_prefill_kernel.hpp:204
Definition: fmha_batch_prefill_kernel.hpp:72
ck_tile::index_t batch_stride_v
Definition: fmha_batch_prefill_kernel.hpp:286
ck_tile::index_t batch_stride_k
Definition: fmha_batch_prefill_kernel.hpp:285
const int32_t * seqstart_q_ptr
Definition: fmha_batch_prefill_kernel.hpp:284
float logits_soft_cap_rcp
Definition: fmha_batch_prefill_kernel.hpp:151
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_batch_prefill_kernel.hpp:136
float logits_soft_cap
Definition: fmha_batch_prefill_kernel.hpp:150
Definition: fmha_batch_prefill_kernel.hpp:174
ck_tile::index_t sink_size
Definition: fmha_batch_prefill_kernel.hpp:176
ck_tile::index_t window_size_right
Definition: fmha_batch_prefill_kernel.hpp:176
ck_tile::index_t window_size_left
Definition: fmha_batch_prefill_kernel.hpp:176
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_batch_prefill_kernel.hpp:177
const int32_t * kv_page_indices
Definition: fmha_batch_prefill_kernel.hpp:81
const int32_t * kv_indptr
Definition: fmha_batch_prefill_kernel.hpp:80
const int32_t * kv_last_page_lens
Definition: fmha_batch_prefill_kernel.hpp:82
const int32_t * block_table_ptr
Definition: fmha_batch_prefill_kernel.hpp:87
const int32_t * seqlen_k_ptr
Definition: fmha_batch_prefill_kernel.hpp:89
ck_tile::index_t batch_stride_block_table
Definition: fmha_batch_prefill_kernel.hpp:88
Definition: fmha_batch_prefill_kernel.hpp:28
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_batch_prefill_kernel.hpp:612
static constexpr bool kIsGroupMode
Definition: fmha_batch_prefill_kernel.hpp:50
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_batch_prefill_kernel.hpp:33
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_batch_prefill_kernel.hpp:39
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, ck_tile::index_t page_block_size, const PageBlockTableKargs &page_table, float scale_s, [[maybe_unused]] float scale_p, [[maybe_unused]] float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_batch_prefill_kernel.hpp:450
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_batch_prefill_kernel.hpp:29
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_batch_prefill_kernel.hpp:38
static constexpr bool kPadSeqLenQ
Definition: fmha_batch_prefill_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_batch_prefill_kernel.hpp:43
static constexpr bool kPadHeadDimV
Definition: fmha_batch_prefill_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_batch_prefill_kernel.hpp:44
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_batch_prefill_kernel.hpp:37
static constexpr auto kKVMemoryLayout
Definition: fmha_batch_prefill_kernel.hpp:60
static constexpr bool kHasMask
Definition: fmha_batch_prefill_kernel.hpp:66
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_batch_prefill_kernel.hpp:41
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_batch_prefill_kernel.hpp:65
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_batch_prefill_kernel.hpp:671
static constexpr bool kPadSeqLenK
Definition: fmha_batch_prefill_kernel.hpp:52
static constexpr auto QScaleEnum
Definition: fmha_batch_prefill_kernel.hpp:59
static constexpr bool kHasLogitsSoftCap
Definition: fmha_batch_prefill_kernel.hpp:55
static constexpr bool kHasDropout
Definition: fmha_batch_prefill_kernel.hpp:58
static constexpr bool kStoreLSE
Definition: fmha_batch_prefill_kernel.hpp:57
static constexpr auto kKVLookupTable
Definition: fmha_batch_prefill_kernel.hpp:61
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_batch_prefill_kernel.hpp:46
static constexpr index_t kPageBlockSize
Definition: fmha_batch_prefill_kernel.hpp:62
ck_tile::remove_cvref_t< typename FmhaPipeline::PDataType > PDataType
Definition: fmha_batch_prefill_kernel.hpp:40
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_batch_prefill_kernel.hpp:683
std::conditional_t< kKVLookupTable==BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D, SglangPageTableKargs, VllmPageTableKargs > PageBlockTableKargs
Definition: fmha_batch_prefill_kernel.hpp:96
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_batch_prefill_kernel.hpp:31
static constexpr auto BiasEnum
Definition: fmha_batch_prefill_kernel.hpp:56
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_batch_prefill_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_batch_prefill_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_batch_prefill_kernel.hpp:64
static constexpr index_t kVectorSize
Definition: fmha_batch_prefill_kernel.hpp:63
static constexpr bool kUseAsyncCopy
Definition: fmha_batch_prefill_kernel.hpp:68
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_batch_prefill_kernel.hpp:35
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_)
Definition: fmha_batch_prefill_kernel.hpp:589
static constexpr bool kPadHeadDimQ
Definition: fmha_batch_prefill_kernel.hpp:53
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_batch_prefill_kernel.hpp:688
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_batch_prefill_kernel.hpp:289
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_batch_prefill_kernel.hpp:30
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, ck_tile::index_t page_block_size, const PageBlockTableKargs &page_table, float scale_s, [[maybe_unused]] float scale_p, [[maybe_unused]] float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_batch_prefill_kernel.hpp:300
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: block_dropout.hpp:39
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: functional.hpp:114
Definition: numeric.hpp:18
Definition: coordinate_transform.hpp:1392
Definition: unary_element_function.hpp:55
Definition: math.hpp:28
Definition: sequence.hpp:49