/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File
fmha_batch_prefill_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
21 
22 namespace ck_tile {
23 
// Batch-prefill FMHA forward kernel. Compile-time configuration is taken
// entirely from the pipeline type; the flags below mirror the pipeline's
// problem description so the kernel and pipeline always agree.
// NOTE(review): this listing was extracted from generated docs; the class
// name line and several alias/member lines are missing here — consult the
// original header before editing.
24 template <typename FmhaPipeline_, typename EpiloguePipeline_>
26 {
// threads per workgroup, as chosen by the pipeline
29  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
30 
// occupancy hint (blocks per CU); must be positive
31  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
32  static_assert(kBlockPerCu > 0);
// raw user-requested occupancy (-1 means "not specified" — see GetName())
33  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
34 
44 
46 
// group mode: variable-length sequences described by seqstart pointers
47  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
// padding flags for each problem dimension (enable bounds-checked loads/stores)
48  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
49  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
50  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
51  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
// optional features compiled in/out via if constexpr throughout this kernel
52  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
53  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
54  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
55  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
56  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
59  static constexpr bool kHasMask = FmhaMask::IsMasking;
60 
// whether the pipeline policy uses asynchronous global->LDS copies
61  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
62 
63  // clang-format off
// clang-format off
// t2s: compile-time map from element type to its short string name,
// used only by GetName() to build the kernel identifier.
64  template <typename T> struct t2s;
65  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
66  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
67  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
68  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
69  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
70  // clang-format on
71 
// Build the canonical kernel-instance name from the compile-time
// configuration (tile shape, warp layout, padding, feature flags).
// The format must stay byte-for-byte in sync with generate.py, which
// parses these names — do not reword any fragment casually.
72  CK_TILE_HOST static std::string GetName()
73  {
74  // sync with generate.py
75  // clang-format off
76  using bfs = typename FmhaPipeline::BlockFmhaShape;
77  using g0br = typename bfs::Gemm0BlockWarps;
78  using g1br = typename bfs::Gemm1BlockWarps;
79  using g0wt = typename bfs::Gemm0WarpTile;
80  using g1wt = typename bfs::Gemm1WarpTile;
81  #define _SS_ std::string
82  #define _TS_ std::to_string
// pn: padding suffix, e.g. "psskddv" — empty when nothing is padded
83  auto pn = [&] () {
84  std::string n;
85  if (kPadSeqLenQ) n += "s";
86  if (kPadSeqLenK) n += "sk";
87  if (kPadHeadDimQ) n += "d";
88  if (kPadHeadDimV) n += "dv";
89  return n.empty() ? n : std::string("p") + n; }();
90  return
91  _SS_("fmha_batch_prefill_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
92  "_" + (kIsGroupMode ? "group" : "batch") + "_"
93  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
94  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
95  "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
96  "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
97  "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
98  "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
// "oN" occupancy tag only appears when the user explicitly requested one
99  (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
100  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
101  (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
102  (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" );
103  #undef _SS_
104  #undef _TS_
105  // clang-format on
106  }
107 
// Empty kernel-argument placeholder. The integer parameter I gives each
// instantiation a distinct type so the Kargs aggregate can inherit from
// several "empty" bases without hitting the duplicated-base-class problem.
108  template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
109  // arg
111  {
112  };
113 
114  // kargs use aggregate initialization, so no constructor will be provided;
115  // use inheritance to minimize karg size.
116  // users need to use the MakeKargs() function to create kargs.
// Common arguments shared by batch and group mode.
// NOTE(review): several member declarations (sizes, strides, page-table
// pointers) were dropped by the docs extractor between the lines below;
// consult the original header for the full member list.
118  {
// base pointers for Q/K/V inputs and the O output (typed later via
// reinterpret_cast to the pipeline's element types)
119  const void* q_ptr;
120  const void* k_ptr;
121  const void* v_ptr;
122  void* o_ptr;
123 
128 
130  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k.
131  // if this param is larger than 1, it indicates the MQA/GQA case.
133 
137 #if 0 // we assume page_block_size=1 for now
138  const int32_t* kv_last_page_lens;
140 #else
// with one token per page, seqlen_k equals the number of page blocks
141  static constexpr ck_tile::index_t page_block_size = 1;
142 #endif
143 
// softmax scale applied to Q*K^T (pre-scaled by log2(e) when fast-exp2 is on)
144  float scale_s;
145 
150 
155  };
156 
// Logits-soft-cap arguments (tanh-style capping of attention logits).
// NOTE(review): the struct name line and the member declarations
// (logits_soft_cap, logits_soft_cap_rcp) are missing from this listing,
// as is the line inside the if-branch that presumably initializes
// logits_soft_cap_rcp — verify against the original header.
158  {
160 
// A positive cap enables the feature; non-positive input disables it by
// zeroing both the cap and its reciprocal.
161  void init_logits_soft_cap(float logits_soft_cap_)
162  {
163  if(0 < logits_soft_cap_)
164  {
165  logits_soft_cap = logits_soft_cap_;
167  }
168  else
169  {
170  logits_soft_cap = 0.f;
171  logits_soft_cap_rcp = 0.f;
172  }
173  }
174 
177  };
178 
// Per-feature kernel-argument fragments. Each struct below is mixed into
// the final Kargs aggregate via std::conditional_t only when the matching
// feature flag is enabled, keeping the kernel-argument size minimal.
// NOTE(review): struct name lines and some stride members are missing from
// this extracted listing.
// --- elementwise bias (common) ---
180  {
181  const void* bias_ptr = nullptr;
184  };
185 
// --- elementwise bias (batch mode; adds batch stride, dropped by extractor) ---
187  {
189  };
190 
// --- ALiBi position bias ---
192  {
193  // alibi is batch*nhead*1; in both batch and group mode the layout is the same
194  const void* alibi_slope_ptr;
195  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 when all batches share one slope
196  };
197 
// --- attention mask (window sizes / mask type members dropped by extractor) ---
199  {
200  // ck_tile::index_t window_size_left, window_size_right;
203  };
204 
// --- fp8 static quantization scales for P and O ---
206  {
207  float scale_p;
208  float scale_o;
209  };
210 
// --- log-sum-exp output (stride members dropped by extractor) ---
212  {
213  void* lse_ptr = nullptr;
216  };
217 
// --- dropout seed/offset: either an immediate value (host) or a device pointer ---
219  {
220  template <typename T>
222  {
223  T val;
224  const T* ptr;
225  };
226 
230  };
231 
// Dropout arguments. Seed/offset may come from host (by value) or from
// device memory (by pointer); is_drop_seed_offset_from_host records which
// union member is active.
// NOTE(review): the lines computing the quantized keep-probability
// (p_undrop_in_uint8_t, referenced later by operator()) are missing from
// this listing — verify against the original header.
233  {
// host-provided seed/offset variant
234  void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
235  {
236  float p_undrop = 1.0 - p_drop;
239  rp_undrop = 1.0 / p_undrop;
240 
241  this->drop_seed.val = seed;
242  this->drop_offset.val = offset;
243  this->is_drop_seed_offset_from_host = true;
244  }
245 
// device-pointer seed/offset variant (values read at kernel run time)
246  void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
247  {
248  float p_undrop = 1.0 - p_drop;
251  rp_undrop = 1.0 / p_undrop;
252 
253  this->drop_seed.ptr = seed_ptr;
254  this->drop_offset.ptr = offset_ptr;
255  this->is_drop_seed_offset_from_host = false;
256  }
257 
// reciprocal of keep probability, used to rescale kept values
258  float rp_undrop = 1;
260  bool is_store_randval = false;
261  void* rand_val_ptr = nullptr;
262 
265  };
266 
// batch-mode dropout adds a batch stride for randval (member dropped by extractor)
268  {
270  };
271 
// Final kernel-argument aggregates. Each enabled feature contributes its
// fragment as a base class; disabled features contribute a distinct empty
// placeholder (FmhaFwdEmptyKargs<N>) so the aggregate stays minimal.
// NOTE(review): the struct-name/inheritance opening lines and the extra
// members (batch strides / seqstart pointer) are missing from this listing.
// --- batch mode ---
274  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
275  FmhaFwdBatchModeBiasKargs,
276  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
277  FmhaFwdAlibiKargs,
278  FmhaFwdEmptyKargs<0>>>,
279  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
280  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
281  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
282  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
283  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
284  {
289  };
290 
// --- group mode (uses the "common" bias/dropout variants; no batch strides for Q/O) ---
293  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
294  FmhaFwdCommonBiasKargs,
295  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
296  FmhaFwdAlibiKargs,
297  FmhaFwdEmptyKargs<0>>>,
298  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
299  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
300  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
301  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
302  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
303  {
307  };
308 
// select the aggregate matching this kernel's mode
309  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
310 
// BlockIndices-like helper (members dropped by extractor; constructed in
// operator() with {i_batch, i_nhead, i_nhead / nhead_ratio_qk})
312  {
316  };
317 
// Build kernel arguments for BATCH mode (fixed seqlen_q per batch; this
// overload is SFINAE-enabled only when kIsGroupMode is false).
// K/V are paged: kv_indptr gives per-batch ranges into kv_page_indices
// (page_block_size is fixed to 1 for now, so one page == one token).
// Feature-specific fields are filled only when the matching flag is
// compiled in; the aggregate's placeholder bases are brace-initialized.
// NOTE(review): the extractor dropped the opening
// `#if CK_TILE_FMHA_FWD_FAST_EXP2` line before the pre-scaled scale_s
// below, and the `if constexpr(BiasEnum == ...ELEMENTWISE_BIAS)` line
// before the bias-assignment braces — verify against the original header.
318  template <bool Cond = !kIsGroupMode>
319  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
320  MakeKargs(const void* q_ptr,
321  const void* k_ptr,
322  const void* v_ptr,
323  const void* bias_ptr,
324  void* rand_val_ptr,
325  void* lse_ptr,
326  void* o_ptr,
327  ck_tile::index_t seqlen_q,
328  ck_tile::index_t hdim_q,
329  ck_tile::index_t hdim_v,
330  ck_tile::index_t num_head_q,
331  ck_tile::index_t nhead_ratio_qk,
332  int32_t num_total_pages,
333  const void* kv_indptr,
334  const void* kv_page_indices,
335 #if 0 // we assume page_block_size=1 for now
336  const void* kv_last_page_lens,
337  ck_tile::index_t page_block_size,
338 #endif
339  float scale_s,
340  float scale_p,
341  float scale_o,
342  float logits_soft_cap,
343  ck_tile::index_t stride_q,
344  ck_tile::index_t stride_k,
345  ck_tile::index_t stride_v,
346  ck_tile::index_t stride_bias,
347  ck_tile::index_t stride_randval,
348  ck_tile::index_t stride_o,
349  ck_tile::index_t nhead_stride_q,
350  ck_tile::index_t nhead_stride_k,
351  ck_tile::index_t nhead_stride_v,
352  ck_tile::index_t nhead_stride_bias,
353  ck_tile::index_t nhead_stride_randval,
354  ck_tile::index_t nhead_stride_lse,
355  ck_tile::index_t nhead_stride_o,
356  ck_tile::index_t batch_stride_q,
357  ck_tile::index_t batch_stride_k,
358  ck_tile::index_t batch_stride_v,
359  ck_tile::index_t batch_stride_bias,
360  ck_tile::index_t batch_stride_randval,
361  ck_tile::index_t batch_stride_lse,
362  ck_tile::index_t batch_stride_o,
363  ck_tile::index_t window_size_left,
364  ck_tile::index_t window_size_right,
365  ck_tile::index_t mask_type,
366  float p_drop,
367  bool s_randval,
368  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
369  drop_seed_offset)
370  {
371  Kargs kargs{{q_ptr,
372  k_ptr,
373  v_ptr,
374  o_ptr,
375  seqlen_q,
// seqlen_k is a sentinel here; it is derived from the page table in operator()
376  -1,
377  hdim_q,
378  hdim_v,
379  num_head_q,
380  nhead_ratio_qk,
381  num_total_pages,
382  reinterpret_cast<const int32_t*>(kv_indptr),
383  reinterpret_cast<const int32_t*>(kv_page_indices),
384 #if 0 // we assume page_block_size=1 for now
385  reinterpret_cast<const int32_t*>(kv_last_page_lens),
386  page_block_size,
387 #endif
// fold log2(e) into scale_s so the pipeline can use exp2 instead of exp
389  static_cast<float>(scale_s * ck_tile::log2e_v<>),
390 #else
391  scale_s,
392 #endif
393  stride_q,
394  stride_k,
395  stride_v,
396  stride_o,
397  nhead_stride_q,
398  nhead_stride_k,
399  nhead_stride_v,
400  nhead_stride_o}, // args for common karg
401  {}, // placeholder for bias
402  {}, // placeholder for mask
403  {}, // placeholder for lse
404  {}, // placeholder for fp8_static_quant args
405  {}, // placeholder for dropout
406  {}, // placeholder for logits_soft_cap
407  batch_stride_q,
408  batch_stride_k,
409  batch_stride_v,
410  batch_stride_o};
411 
// (missing line 412 presumably: if constexpr(BiasEnum == ...ELEMENTWISE_BIAS))
413  {
414  kargs.bias_ptr = bias_ptr;
415  kargs.stride_bias = stride_bias;
416  kargs.nhead_stride_bias = nhead_stride_bias;
417  kargs.batch_stride_bias = batch_stride_bias;
418  }
419  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
420  {
// for ALiBi the "bias" pointer/stride carry the per-head slopes instead
421  kargs.alibi_slope_ptr = bias_ptr;
422  kargs.alibi_slope_stride = stride_bias;
423  }
424  if constexpr(kHasMask)
425  {
426  kargs.window_size_left = window_size_left;
427  kargs.window_size_right = window_size_right;
428  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
429  }
430  if constexpr(kStoreLSE)
431  {
432  kargs.lse_ptr = lse_ptr;
433  kargs.nhead_stride_lse = nhead_stride_lse;
434  kargs.batch_stride_lse = batch_stride_lse;
435  }
436  if constexpr(kDoFp8StaticQuant)
437  {
438  kargs.scale_p = scale_p;
439  kargs.scale_o = scale_o;
440  }
441  if constexpr(kHasDropout)
442  {
// the variant index distinguishes host values (0) from device pointers (1)
443  if(drop_seed_offset.index() == 0) // seed & offset come from host
444  {
445  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
446  kargs.init_dropout(p_drop, seed, offset);
447  }
448  else // seed & offset come from device
449  {
450  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
451  kargs.init_dropout(p_drop,
452  reinterpret_cast<const uint64_t*>(seed_ptr),
453  reinterpret_cast<const uint64_t*>(offset_ptr));
454  }
455 
456  kargs.rand_val_ptr = rand_val_ptr;
457  kargs.stride_randval = stride_randval;
458  kargs.nhead_stride_randval = nhead_stride_randval;
459  kargs.batch_stride_randval = batch_stride_randval;
460  kargs.is_store_randval = s_randval;
461  }
462  if constexpr(kHasLogitsSoftCap)
463  {
464  kargs.init_logits_soft_cap(logits_soft_cap);
465  }
466 
467  return kargs;
468  }
469 
// Build kernel arguments for GROUP mode (variable-length sequences;
// this overload is SFINAE-enabled only when kIsGroupMode is true).
// Per-batch query lengths come from seqstart_q_ptr prefix sums, so no
// seqlen_q / batch_stride_q / batch_stride_o arguments are taken; both
// seqlen fields start as -1 sentinels and are resolved in operator().
// NOTE(review): same extractor gaps as the batch-mode overload — the
// `#if CK_TILE_FMHA_FWD_FAST_EXP2` opener and the elementwise-bias
// `if constexpr` line are missing below.
470  template <bool Cond = kIsGroupMode>
471  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
472  MakeKargs(const void* q_ptr,
473  const void* k_ptr,
474  const void* v_ptr,
475  const void* bias_ptr,
476  void* rand_val_ptr,
477  void* lse_ptr,
478  void* o_ptr,
479  const void* seqstart_q_ptr,
480  ck_tile::index_t hdim_q,
481  ck_tile::index_t hdim_v,
482  ck_tile::index_t num_head_q,
483  ck_tile::index_t nhead_ratio_qk,
484  int32_t num_total_pages,
485  const void* kv_indptr,
486  const void* kv_page_indices,
487 #if 0 // we assume page_block_size=1 for now
488  const void* kv_last_page_lens,
489  ck_tile::index_t page_block_size,
490 #endif
491  float scale_s,
492  float scale_p,
493  float scale_o,
494  float logits_soft_cap,
495  ck_tile::index_t stride_q,
496  ck_tile::index_t stride_k,
497  ck_tile::index_t stride_v,
498  ck_tile::index_t stride_bias,
499  ck_tile::index_t stride_randval,
500  ck_tile::index_t stride_o,
501  ck_tile::index_t nhead_stride_q,
502  ck_tile::index_t nhead_stride_k,
503  ck_tile::index_t nhead_stride_v,
504  ck_tile::index_t nhead_stride_bias,
505  ck_tile::index_t nhead_stride_randval,
506  ck_tile::index_t nhead_stride_lse,
507  ck_tile::index_t nhead_stride_o,
508  ck_tile::index_t batch_stride_k,
509  ck_tile::index_t batch_stride_v,
510  ck_tile::index_t window_size_left,
511  ck_tile::index_t window_size_right,
512  ck_tile::index_t mask_type,
513  float p_drop,
514  bool s_randval,
515  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
516  drop_seed_offset)
517  {
518  Kargs kargs{{q_ptr,
519  k_ptr,
520  v_ptr,
521  o_ptr,
522  -1, // seqlen will be updated by another pointer
523  -1, //
524  hdim_q,
525  hdim_v,
526  num_head_q,
527  nhead_ratio_qk,
528  num_total_pages,
529  reinterpret_cast<const int32_t*>(kv_indptr),
530  reinterpret_cast<const int32_t*>(kv_page_indices),
531 #if 0 // we assume page_block_size=1 for now
532  reinterpret_cast<const int32_t*>(kv_last_page_lens),
533  page_block_size,
534 #endif
// fold log2(e) into scale_s so the pipeline can use exp2 instead of exp
536  static_cast<float>(scale_s * ck_tile::log2e_v<>),
537 #else
538  scale_s,
539 #endif
540  stride_q,
541  stride_k,
542  stride_v,
543  stride_o,
544  nhead_stride_q,
545  nhead_stride_k,
546  nhead_stride_v,
547  nhead_stride_o}, // args for common karg
548  {}, // placeholder for bias
549  {}, // placeholder for mask
550  {}, // placeholder for lse
551  {}, // placeholder for fp8_static_quant args
552  {}, // placeholder for dropout
553  {}, // placeholder for logits_soft_cap
554  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
555  batch_stride_k,
556  batch_stride_v};
557 
// (missing line 558 presumably: if constexpr(BiasEnum == ...ELEMENTWISE_BIAS))
559  {
560  kargs.bias_ptr = bias_ptr;
561  kargs.stride_bias = stride_bias;
562  kargs.nhead_stride_bias = nhead_stride_bias;
563  }
564  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
565  {
// for ALiBi the "bias" pointer/stride carry the per-head slopes instead
566  kargs.alibi_slope_ptr = bias_ptr;
567  kargs.alibi_slope_stride = stride_bias;
568  }
569  if constexpr(kHasMask)
570  {
571  kargs.window_size_left = window_size_left;
572  kargs.window_size_right = window_size_right;
573  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
574  }
575  if constexpr(kStoreLSE)
576  {
577  kargs.lse_ptr = lse_ptr;
578  kargs.nhead_stride_lse = nhead_stride_lse;
579  }
580  if constexpr(kDoFp8StaticQuant)
581  {
582  kargs.scale_p = scale_p;
583  kargs.scale_o = scale_o;
584  }
585  if constexpr(kHasDropout)
586  {
// the variant index distinguishes host values (0) from device pointers (1)
587  if(drop_seed_offset.index() == 0) // seed & offset come from host
588  {
589  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
590  kargs.init_dropout(p_drop, seed, offset);
591  }
592  else // seed & offset come from device
593  {
594  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
595  kargs.init_dropout(p_drop,
596  reinterpret_cast<const uint64_t*>(seed_ptr),
597  reinterpret_cast<const uint64_t*>(offset_ptr));
598  }
599 
600  kargs.rand_val_ptr = rand_val_ptr;
601  kargs.stride_randval = stride_randval;
602  kargs.nhead_stride_randval = nhead_stride_randval;
603  kargs.is_store_randval = s_randval;
604  }
605  if constexpr(kHasLogitsSoftCap)
606  {
607  kargs.init_logits_soft_cap(logits_soft_cap);
608  }
609 
610  return kargs;
611  }
612 
613  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
614  ck_tile::index_t nhead_,
615  ck_tile::index_t seqlen_q_,
616  ck_tile::index_t hdim_v_)
617  {
618  if constexpr(kIsGroupMode)
619  {
620  // TODO: this may need tuning
621  return dim3(nhead_,
622  batch_size_,
623  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
624  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
625  }
626  else
627  {
628  // TODO: this may need tuning
629  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
630  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
631  nhead_,
632  batch_size_);
633  }
634  }
635 
636  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
637  {
638  if constexpr(kIsGroupMode)
639  {
640  // const index_t num_tile_m0 = seqlen_q / kM0;
641  const index_t num_tile_n1 =
642  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
643 
644  const index_t i_block = blockIdx.z;
645  const index_t i_nhead = blockIdx.x;
646  const index_t i_batch = blockIdx.y;
647 
648  const auto f = [](index_t dividend, index_t divisor) {
649  index_t quotient = dividend / divisor;
650  index_t modulus = dividend - quotient * divisor;
651  return ck_tile::make_tuple(quotient, modulus);
652  };
653 
654  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
655  if constexpr(kHasMask)
656  {
657  // assume that num_tile_n1 is always 1
658  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
659  }
660  else
661  {
662  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
663  }
664  }
665  else
666  {
667  // const index_t num_tile_m0 = seqlen_q / kM0;
668  const index_t num_tile_n1 =
669  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
670 
671  const index_t i_block = blockIdx.x;
672  const index_t i_nhead = blockIdx.y;
673  const index_t i_batch = blockIdx.z;
674 
675  const auto f = [](index_t dividend, index_t divisor) {
676  index_t quotient = dividend / divisor;
677  index_t modulus = dividend - quotient * divisor;
678  return ck_tile::make_tuple(quotient, modulus);
679  };
680 
681  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
682 
683  if constexpr(kHasMask)
684  {
685  // assume that num_tile_n1 is always 1
686  return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
687  }
688  else
689  {
690  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
691  }
692  }
693  }
694 
// 1-D workgroup: kBlockSize threads.
695  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
696 
// LDS requirement: the pipeline and epilogue reuse the same buffer, so the
// kernel needs only the larger of the two.
// NOTE(review): the function signature line (GetSmemSize) was dropped by
// the docs extractor.
698  {
699  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
700  }
701 
// Kernel entry point. Per workgroup: resolve this block's tile/head/batch,
// compute per-batch pointer offsets (group mode reads them from seqstart
// prefix sums; batch mode uses batch strides), derive seqlen_k from the
// paged-KV indptr, build DRAM tensor views/windows for Q/K/V/bias/LSE/
// randval, run the FMHA pipeline, then write O through the epilogue.
// NOTE(review): this extracted listing is missing many lines (guard
// stride arguments to make_naive_tensor_view, pad sequences, window
// lengths, some if-constexpr headers) — verify any edit against the
// original header.
702  CK_TILE_DEVICE void operator()(Kargs kargs) const
703  {
704  // allocate LDS
705  __shared__ char smem_ptr[GetSmemSize()];
706 
707  // divide problem
708  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
709 
// readfirstlane: tile origins are wavefront-uniform, keep them in SGPRs
710  const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
711  const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
712 
713  long_index_t batch_offset_q = 0;
714  long_index_t batch_offset_bias = 0;
715  long_index_t batch_offset_randval = 0;
716  long_index_t batch_offset_lse = 0;
717  long_index_t batch_offset_o = 0;
718 
// number of KV pages owned by this batch (indptr is a prefix-sum array)
719  const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
720 #if 0 // we assume page_block_size=1 for now
721  const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
722 #endif
723  if constexpr(kIsGroupMode)
724  {
725  // get starting offset for each batch
726  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
727 
728  batch_offset_q = query_start * kargs.stride_q;
729 
// advance the page table to this batch's first page
730  kargs.kv_page_indices += kargs.kv_indptr[i_batch];
731 
// (missing line 732 presumably: if constexpr(BiasEnum == ...ELEMENTWISE_BIAS))
733  {
734  batch_offset_bias = query_start * kargs.stride_bias;
735  }
736  if constexpr(kStoreLSE)
737  {
738  batch_offset_lse = query_start;
739  }
740  if constexpr(kHasDropout)
741  {
742  batch_offset_randval = query_start * kargs.stride_randval;
743  }
744  batch_offset_o = query_start * kargs.stride_o;
745 
746  // get real # queries & # keys under group mode
747  kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - query_start;
748 
749  // # of required blocks is different in each group; terminate unnecessary blocks
750  // earlier
751  if(kargs.seqlen_q <= i_m0)
752  {
753  return;
754  }
755 
756 #if 0 // we assume page_block_size=1 for now
757  kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
758 #else
// one token per page, so the key length equals the page count
759  kargs.seqlen_k = num_page_blocks;
760 #endif
761  }
762  else
763  {
764  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
765 
// advance the page table to this batch's first page
766  kargs.kv_page_indices += kargs.kv_indptr[i_batch];
767 
// (missing line 768 presumably: if constexpr(BiasEnum == ...ELEMENTWISE_BIAS))
769  {
770  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
771  }
772  if constexpr(kStoreLSE)
773  {
774  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
775  }
776  if constexpr(kHasDropout)
777  {
778  batch_offset_randval =
779  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
780  }
781  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
782 
783 #if 0 // we assume page_block_size=1 for now
784  kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
785 #else
786  kargs.seqlen_k = num_page_blocks;
787 #endif
788  }
789 
790  // for simplicity, batch stride we just modify the pointer
// K/V have no batch offset here: they are addressed through the page table
791  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
792  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
793  batch_offset_q;
794  const KDataType* k_ptr =
795  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
// MQA/GQA: several query heads map onto one KV head
796  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
797  const VDataType* v_ptr =
798  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
799  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v;
800  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
801  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
802  batch_offset_o;
803 
804  // Q/K/V DRAM and DRAM window
805  const auto q_dram = [&]() {
806  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
807  q_ptr,
808  make_tuple(kargs.seqlen_q, kargs.hdim_q),
809  make_tuple(kargs.stride_q, 1),
811  number<1>{});
// kQLoadOnce changes the Q tile shape, hence two padding variants
812  if constexpr(FmhaPipeline::kQLoadOnce)
813  {
814  return pad_tensor_view(
815  q_dram_naive,
818  }
819  else
820  {
821  return pad_tensor_view(
822  q_dram_naive,
825  }
826  }();
827  const auto k_dram = [&]() {
// K spans the whole page pool; per-batch pages are selected via kv_page_indices
828  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
829  k_ptr,
830  make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q),
831  make_tuple(kargs.stride_k, 1),
833  number<1>{});
834 
// non-async copies always pad seqlen_k so gather stays in bounds
835  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
836  return pad_tensor_view(
837  k_dram_naive,
840  }();
841  const auto v_dram = [&]() {
842  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
843  {
844  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
845  v_ptr,
846  make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
847  make_tuple(kargs.stride_v, 1),
849  number<1>{});
850 
// pipeline consumes V as [hdim_v, seqlen_k]; swap the axes via a view
851  const auto v_dram_transposed = transform_tensor_view(
852  v_dram_naive,
853  make_tuple(
854  make_pass_through_transform(kargs.hdim_v),
855  make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
858 
859  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
860  return pad_tensor_view(
861  v_dram_transposed,
864  }
865  else
866  {
867  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
868  v_ptr,
869  make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size),
870  make_tuple(kargs.stride_v, 1),
872  number<1>{});
873 
874  constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
875  return pad_tensor_view(
876  v_dram_naive,
879  }
880  }();
881 
882  auto q_dram_window = make_tile_window(
883  q_dram,
884  [&]() {
885  if constexpr(FmhaPipeline::kQLoadOnce)
888  else
890  }(),
891  {i_m0, 0});
892 
// K/V windows start at the origin; the pipeline walks seqlen_k itself
893  auto k_dram_window = make_tile_window(
895 
896  auto v_dram_window =
897  make_tile_window(v_dram,
899  {i_n1, 0});
// (missing line presumably: if constexpr(BiasEnum == ...ELEMENTWISE_BIAS))
902  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
903  constexpr auto bias_dram_window_lengths =
906  {
907  const BiasDataType* bias_ptr =
908  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
909  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
910  batch_offset_bias;
911 
912  const auto bias_dram = [&]() {
913  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
914  bias_ptr,
915  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
916  make_tuple(kargs.stride_bias, 1),
918  number<1>{});
919 
920  return pad_tensor_view(bias_dram_naive,
921  bias_dram_window_lengths,
923  }();
924 
925  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
926  }
927  else
928  {
// null window: disabled features cost nothing at run time
929  return make_null_tile_window(bias_dram_window_lengths);
930  }
931  }();
932 
933  // lse
934  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
935  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
936  if constexpr(kStoreLSE)
937  {
938  LSEDataType* lse_ptr =
939  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
940  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
941 
942  const auto lse_dram = [&]() {
943  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
944  lse_ptr,
945  make_tuple(kargs.seqlen_q),
946  make_tuple(1),
947  number<1>{},
948  number<1>{});
949 
950  return pad_tensor_view(
951  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
952  }();
953 
954  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
955  }
956  else
957  {
958  return make_null_tile_window(lse_dram_window_lengths);
959  }
960  }();
961 
// dropout state: dereferencing drop_seed/drop_offset as value vs pointer
// depends on how init_dropout() was called on the host
962  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
963  if constexpr(kHasDropout)
964  {
965  return BlockDropout{i_batch_,
966  i_nhead_,
967  kargs.num_head_q,
968  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
969  : *kargs.drop_seed.ptr,
970  kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
971  : *kargs.drop_offset.ptr,
972  kargs.rp_undrop,
973  kargs.p_undrop_in_uint8_t,
974  kargs.is_store_randval};
975  }
976  else
977  {
978  return NullBlockDropout{};
979  };
980  }();
981 
982  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
983  constexpr auto randval_dram_window_lengths =
985  if constexpr(kHasDropout)
986  {
987  RandValOutputDataType* rand_val_ptr =
988  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
989  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
990  batch_offset_randval;
991 
992  const auto randval_dram = [&]() {
993  const auto randval_dram_naive =
994  make_naive_tensor_view<address_space_enum::global>(
995  rand_val_ptr,
996  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
997  make_tuple(kargs.stride_randval, 1),
998  number<1>{},
999  number<1>{});
1000 
1001  return pad_tensor_view(randval_dram_naive,
1002  randval_dram_window_lengths,
1004  }();
1005 
1006  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1007  }
1008  else
1009  {
1010  return make_null_tile_window(randval_dram_window_lengths);
1011  }
1012  }();
1013 
1014  FmhaMask mask = [&]() {
1015  if constexpr(kHasMask)
1016  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1017  kargs.window_size_left,
1018  kargs.window_size_right,
1019  kargs.seqlen_q,
1020  kargs.seqlen_k,
1022  else
1023  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1024  }();
1025 
1026  // WA: i_batch cannot be captured as a structured binding before C++20
1027  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1028  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1029  {
1030  // data loading, shared by entire wg
1031  // TODO: how to use s_read?
1032  SaccDataType slope =
1033  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1034  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1035 #if CK_TILE_FMHA_FWD_FAST_EXP2
// pre-scale the slope to match the exp2-based softmax
1036  slope *= ck_tile::log2e_v<>;
1037 #endif
1038  if constexpr(kHasMask)
1039  {
1040  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1041  kargs.window_size_left,
1042  kargs.window_size_right,
1043  kargs.seqlen_q,
1044  kargs.seqlen_k,
1045  kargs.mask_type);
1046  }
1047  else
1048  {
// (missing line 1049 presumably: return Alibi<SaccDataType, true>{)
1050  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1051  }
1052  }
1053  else
1054  {
// (missing line 1055 presumably returns an empty position encoding)
1056  }
1057  }();
1058 
1059  AttentionVariant variant;
1060  const auto variant_params = [&] {
1061  if constexpr(kHasLogitsSoftCap)
1062  {
// (missing line 1063 presumably: return ck_tile::LogitsSoftCapParams<...>{)
1064  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1065  }
1066  else
1067  {
1068  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1069  }
1070  }();
1071 
1072  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1073 
// run the pipeline; the fp8 path threads element-wise scale/saturate
// functors through, the normal path uses the shorter overload
1074  auto o_acc_tile = [&]() {
1075  if constexpr(kDoFp8StaticQuant)
1076  {
1077  return FmhaPipeline{}(
1078  q_dram_window,
1079  identity{}, // q_element_func
1080  k_dram_window,
1081  identity{}, // k_element_func
1082  v_dram_window,
1083  identity{}, // v_element_func
1084  bias_dram_window,
1085  identity{}, // bias_element_func
1086  randval_dram_window,
1087  lse_dram_window,
1088  identity{}, // lse_element_func
1089  identity{}, // s_acc_element_func
1090  scales{kargs.scale_p}, // p_compute_element_func
1091  composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
1092  mask,
1093  position_encoding,
1094  kargs.scale_s,
1095  variant,
1096  variant_params,
1097  block_indices,
1098  smem_ptr,
1099  kargs.kv_page_indices,
1100  kargs.stride_k,
1101  kargs.stride_v,
1102  dropout);
1103  }
1104  else
1105  {
1106  return FmhaPipeline{}(q_dram_window,
1107  k_dram_window,
1108  v_dram_window,
1109  bias_dram_window,
1110  randval_dram_window,
1111  lse_dram_window,
1112  mask,
1113  position_encoding,
1114  kargs.scale_s,
1115  variant,
1116  variant_params,
1117  block_indices,
1118  smem_ptr,
1119  kargs.kv_page_indices,
1120  kargs.stride_k,
1121  kargs.stride_v,
1122  dropout);
1123  }
1124  }();
1125 
1126  // O DRAM and O DRAM window
1127  auto o_dram = [&]() {
1128  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1129  o_ptr,
1130  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1131  make_tuple(kargs.stride_o, 1),
1133  number<1>{});
1134 
1135  return pad_tensor_view(
1136  o_dram_naive,
1139  }();
1140 
1141  auto o_dram_window =
1142  make_tile_window(o_dram,
1144  {i_m0, i_n1});
1145 
// epilogue converts/stores the accumulator tile to global memory
1146  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1147  }
1148 };
1149 
1150 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition: config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
#define _TS_
#define _SS_
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:511
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create macro to replace 'host device' and nothing more.
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
unsigned char uint8_t
Definition: stdint.h:124
unsigned __int64 uint64_t
Definition: stdint.h:136
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_dropout.hpp:26
const float rp_undrop
Definition: block_dropout.hpp:290
Definition: block_position_encoding.hpp:137
Definition: fmha_batch_prefill_kernel.hpp:312
ck_tile::index_t kv_head_idx
Definition: fmha_batch_prefill_kernel.hpp:315
ck_tile::index_t qo_head_idx
Definition: fmha_batch_prefill_kernel.hpp:314
ck_tile::index_t batch_idx
Definition: fmha_batch_prefill_kernel.hpp:313
Definition: fmha_batch_prefill_kernel.hpp:192
ck_tile::index_t alibi_slope_stride
Definition: fmha_batch_prefill_kernel.hpp:195
const void * alibi_slope_ptr
Definition: fmha_batch_prefill_kernel.hpp:194
ck_tile::index_t batch_stride_bias
Definition: fmha_batch_prefill_kernel.hpp:188
ck_tile::index_t batch_stride_randval
Definition: fmha_batch_prefill_kernel.hpp:269
ck_tile::index_t batch_stride_o
Definition: fmha_batch_prefill_kernel.hpp:288
ck_tile::index_t batch_stride_v
Definition: fmha_batch_prefill_kernel.hpp:287
ck_tile::index_t batch_stride_q
Definition: fmha_batch_prefill_kernel.hpp:285
ck_tile::index_t batch_stride_k
Definition: fmha_batch_prefill_kernel.hpp:286
ck_tile::index_t nhead_stride_bias
Definition: fmha_batch_prefill_kernel.hpp:183
ck_tile::index_t stride_bias
Definition: fmha_batch_prefill_kernel.hpp:182
const void * bias_ptr
Definition: fmha_batch_prefill_kernel.hpp:181
ck_tile::index_t stride_randval
Definition: fmha_batch_prefill_kernel.hpp:263
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition: fmha_batch_prefill_kernel.hpp:246
ck_tile::index_t nhead_stride_randval
Definition: fmha_batch_prefill_kernel.hpp:264
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition: fmha_batch_prefill_kernel.hpp:234
void * rand_val_ptr
Definition: fmha_batch_prefill_kernel.hpp:261
float rp_undrop
Definition: fmha_batch_prefill_kernel.hpp:258
bool is_store_randval
Definition: fmha_batch_prefill_kernel.hpp:260
uint8_t p_undrop_in_uint8_t
Definition: fmha_batch_prefill_kernel.hpp:259
Definition: fmha_batch_prefill_kernel.hpp:118
ck_tile::index_t stride_q
Definition: fmha_batch_prefill_kernel.hpp:146
ck_tile::index_t stride_v
Definition: fmha_batch_prefill_kernel.hpp:148
int32_t num_total_pages
Definition: fmha_batch_prefill_kernel.hpp:134
float scale_s
Definition: fmha_batch_prefill_kernel.hpp:144
ck_tile::index_t seqlen_q
Definition: fmha_batch_prefill_kernel.hpp:124
ck_tile::index_t stride_k
Definition: fmha_batch_prefill_kernel.hpp:147
ck_tile::index_t nhead_stride_o
Definition: fmha_batch_prefill_kernel.hpp:154
ck_tile::index_t nhead_stride_k
Definition: fmha_batch_prefill_kernel.hpp:152
ck_tile::index_t nhead_ratio_qk
Definition: fmha_batch_prefill_kernel.hpp:132
ck_tile::index_t nhead_stride_v
Definition: fmha_batch_prefill_kernel.hpp:153
ck_tile::index_t nhead_stride_q
Definition: fmha_batch_prefill_kernel.hpp:151
const int32_t * kv_page_indices
Definition: fmha_batch_prefill_kernel.hpp:136
const void * v_ptr
Definition: fmha_batch_prefill_kernel.hpp:121
const int32_t * kv_indptr
Definition: fmha_batch_prefill_kernel.hpp:135
void * o_ptr
Definition: fmha_batch_prefill_kernel.hpp:122
ck_tile::index_t seqlen_k
Definition: fmha_batch_prefill_kernel.hpp:125
ck_tile::index_t stride_o
Definition: fmha_batch_prefill_kernel.hpp:149
ck_tile::index_t hdim_v
Definition: fmha_batch_prefill_kernel.hpp:127
ck_tile::index_t num_head_q
Definition: fmha_batch_prefill_kernel.hpp:129
static constexpr ck_tile::index_t page_block_size
Definition: fmha_batch_prefill_kernel.hpp:141
const void * k_ptr
Definition: fmha_batch_prefill_kernel.hpp:120
ck_tile::index_t hdim_q
Definition: fmha_batch_prefill_kernel.hpp:126
const void * q_ptr
Definition: fmha_batch_prefill_kernel.hpp:119
ck_tile::index_t batch_stride_lse
Definition: fmha_batch_prefill_kernel.hpp:215
ck_tile::index_t nhead_stride_lse
Definition: fmha_batch_prefill_kernel.hpp:214
void * lse_ptr
Definition: fmha_batch_prefill_kernel.hpp:213
bool is_drop_seed_offset_from_host
Definition: fmha_batch_prefill_kernel.hpp:229
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_batch_prefill_kernel.hpp:227
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_batch_prefill_kernel.hpp:228
Definition: fmha_batch_prefill_kernel.hpp:111
float scale_p
Definition: fmha_batch_prefill_kernel.hpp:207
float scale_o
Definition: fmha_batch_prefill_kernel.hpp:208
ck_tile::index_t batch_stride_v
Definition: fmha_batch_prefill_kernel.hpp:306
ck_tile::index_t batch_stride_k
Definition: fmha_batch_prefill_kernel.hpp:305
const int32_t * seqstart_q_ptr
Definition: fmha_batch_prefill_kernel.hpp:304
float logits_soft_cap_rcp
Definition: fmha_batch_prefill_kernel.hpp:176
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_batch_prefill_kernel.hpp:161
float logits_soft_cap
Definition: fmha_batch_prefill_kernel.hpp:175
Definition: fmha_batch_prefill_kernel.hpp:199
ck_tile::index_t window_size_right
Definition: fmha_batch_prefill_kernel.hpp:201
ck_tile::index_t window_size_left
Definition: fmha_batch_prefill_kernel.hpp:201
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_batch_prefill_kernel.hpp:202
Definition: fmha_batch_prefill_kernel.hpp:64
Definition: fmha_batch_prefill_kernel.hpp:26
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_batch_prefill_kernel.hpp:636
static constexpr bool kIsGroupMode
Definition: fmha_batch_prefill_kernel.hpp:47
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_batch_prefill_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_batch_prefill_kernel.hpp:37
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_batch_prefill_kernel.hpp:27
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_batch_prefill_kernel.hpp:36
static constexpr bool kPadSeqLenQ
Definition: fmha_batch_prefill_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_batch_prefill_kernel.hpp:40
static constexpr bool kDoFp8StaticQuant
Definition: fmha_batch_prefill_kernel.hpp:56
static constexpr bool kPadHeadDimV
Definition: fmha_batch_prefill_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_batch_prefill_kernel.hpp:41
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_batch_prefill_kernel.hpp:35
static constexpr bool kHasMask
Definition: fmha_batch_prefill_kernel.hpp:59
static CK_TILE_HOST std::string GetName()
Definition: fmha_batch_prefill_kernel.hpp:72
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_batch_prefill_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_batch_prefill_kernel.hpp:58
static constexpr bool kPadSeqLenK
Definition: fmha_batch_prefill_kernel.hpp:49
static constexpr bool kHasLogitsSoftCap
Definition: fmha_batch_prefill_kernel.hpp:52
static constexpr bool kHasDropout
Definition: fmha_batch_prefill_kernel.hpp:55
static constexpr bool kStoreLSE
Definition: fmha_batch_prefill_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_batch_prefill_kernel.hpp:43
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_batch_prefill_kernel.hpp:697
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_batch_prefill_kernel.hpp:29
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, const void *kv_indptr, const void *kv_page_indices, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_batch_prefill_kernel.hpp:472
static constexpr auto BiasEnum
Definition: fmha_batch_prefill_kernel.hpp:53
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_batch_prefill_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_batch_prefill_kernel.hpp:42
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_batch_prefill_kernel.hpp:57
static constexpr bool kUseAsyncCopy
Definition: fmha_batch_prefill_kernel.hpp:61
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, const void *kv_indptr, const void *kv_page_indices, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_batch_prefill_kernel.hpp:320
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_batch_prefill_kernel.hpp:33
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_)
Definition: fmha_batch_prefill_kernel.hpp:613
static constexpr bool kPadHeadDimQ
Definition: fmha_batch_prefill_kernel.hpp:50
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_batch_prefill_kernel.hpp:695
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_batch_prefill_kernel.hpp:702
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_batch_prefill_kernel.hpp:309
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_batch_prefill_kernel.hpp:28
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: block_dropout.hpp:12
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: functional.hpp:86
Definition: coordinate_transform.hpp:1392
Definition: unary_element_function.hpp:56
Definition: math.hpp:28
Definition: sequence.hpp:49