/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File
fmha_batch_prefill_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
21 
22 namespace ck_tile {
23 
24 template <typename FmhaPipeline_, typename EpiloguePipeline_>
26 {
29  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
30 
31  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
32  static_assert(kBlockPerCu > 0);
33  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
34 
44 
46 
47  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
48  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
49  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
50  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
51  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
52  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
53  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
54  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
55  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
56  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
59  static constexpr bool kHasMask = FmhaMask::IsMasking;
60 
61  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
62 
63  // clang-format off
// t2s ("type to string"): maps an element data type to the short tag used in
// the kernel-instance name built by GetName() below.
64  template <typename T> struct t2s;
65  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
66  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
67  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
68  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
69  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
70  // clang-format on
71 
// Builds the canonical kernel-instance name: head dim, data type, mode,
// block/warp tile shapes, then one tag per enabled trait (padding, bias,
// mask, LSE, dropout, fp8 static-quant). The exact encoding is a contract
// with the host-side generate.py script — do not change it independently.
72  CK_TILE_HOST static std::string GetName()
73  {
74  // sync with generate.py
75  // clang-format off
76  using bfs = typename FmhaPipeline::BlockFmhaShape;
77  using g0br = typename bfs::Gemm0BlockWarps;
78  using g1br = typename bfs::Gemm1BlockWarps;
79  using g0wt = typename bfs::Gemm0WarpTile;
80  using g1wt = typename bfs::Gemm1WarpTile;
81  #define _SS_ std::string
82  #define _TS_ std::to_string
// pn: padding tag, e.g. "psskddv" when all four dims are padded; empty when none.
83  auto pn = [&] () {
84  std::string n;
85  if (kPadSeqLenQ) n += "s";
86  if (kPadSeqLenK) n += "sk";
87  if (kPadHeadDimQ) n += "d";
88  if (kPadHeadDimV) n += "dv";
89  return n.empty() ? n : std::string("p") + n; }();
90  return
91  _SS_("fmha_batch_prefill_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
92  "_" + (kIsGroupMode ? "group" : "batch") + "_"
93  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
94  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
95  "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
96  "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
97  "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
98  "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
99  (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
100  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
101  (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
102  (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" );
103  #undef _SS_
104  #undef _TS_
105  // clang-format on
106  }
107 
// Empty placeholder kargs. The dummy index parameter I makes every
// instantiation a distinct type, so the mode-specific karg structs below can
// inherit several placeholders without hitting the duplicated-base-class rule.
108  template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
109  // arg
// NOTE(review): the struct declaration line (original line 110, presumably
// "struct FmhaFwdEmptyKargs") was dropped by the doc extraction — confirm
// against the upstream source.
111  {
112  };
113 
114  // kargs use aggregate initializer, so no constructor will provided
115  // use inheritance to minimize karg size
116  // user need to use MakeKargs() function to create kargs.
118  {
119  const void* q_ptr;
120  const void* k_ptr;
121  const void* v_ptr;
122  void* o_ptr;
123 
128 
130  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
131  // if this param is larger than 1, indicate MQA/GQA case
133 
137 #if 0 // we assume page_block_size=1 for now
138  const int32_t* kv_last_page_lens;
140 #else
141  static constexpr ck_tile::index_t page_block_size = 1;
142 #endif
143 
144  float scale_s;
145 
150 
155  };
156 
158  {
160 
// Enables logits soft-capping when the given cap is strictly positive; a
// non-positive value disables it by zeroing both the cap and its reciprocal.
161  void init_logits_soft_cap(float logits_soft_cap_)
162  {
163  if(0 < logits_soft_cap_)
164  {
165  logits_soft_cap = logits_soft_cap_;
// NOTE(review): original line 166 was dropped by the doc extraction; given the
// else-branch below, it presumably sets logits_soft_cap_rcp to the reciprocal
// of the cap — confirm against the upstream source.
167  }
168  else
169  {
170  logits_soft_cap = 0.f;
171  logits_soft_cap_rcp = 0.f;
172  }
173  }
174 
177  };
178 
180  {
181  const void* bias_ptr = nullptr;
184  };
185 
187  {
189  };
190 
192  {
193  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
194  const void* alibi_slope_ptr;
195  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
196  };
197 
199  {
200  // ck_tile::index_t window_size_left, window_size_right;
203  };
204 
206  {
207  float scale_p;
208  float scale_o;
209  };
210 
212  {
213  void* lse_ptr = nullptr;
216  };
217 
219  {
220  template <typename T>
222  {
223  T val;
224  const T* ptr;
225  };
226 
230  };
231 
233  {
// Initializes dropout state with a host-provided RNG seed/offset pair and
// records the reciprocal keep-probability used to rescale surviving values.
234  void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
235  {
236  float p_undrop = 1.0 - p_drop;
// NOTE(review): original lines 237-238 were dropped by the doc extraction;
// they presumably derive p_undrop_in_uint8_t from p_undrop — confirm against
// the upstream source.
239  rp_undrop = 1.0 / p_undrop;
240 
241  this->drop_seed.val = seed;
242  this->drop_offset.val = offset;
243  this->is_drop_seed_offset_from_host = true;
244  }
245 
// Same as the overload above, but the RNG seed/offset live in device memory
// and are dereferenced inside the kernel (is_drop_seed_offset_from_host=false).
246  void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
247  {
248  float p_undrop = 1.0 - p_drop;
// NOTE(review): original lines 249-250 were dropped by the doc extraction;
// they presumably derive p_undrop_in_uint8_t from p_undrop — confirm against
// the upstream source.
251  rp_undrop = 1.0 / p_undrop;
252 
253  this->drop_seed.ptr = seed_ptr;
254  this->drop_offset.ptr = offset_ptr;
255  this->is_drop_seed_offset_from_host = false;
256  }
257 
258  float rp_undrop = 1;
260  bool is_store_randval = false;
261  void* rand_val_ptr = nullptr;
262 
265  };
266 
268  {
270  };
271 
274  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
275  FmhaFwdBatchModeBiasKargs,
276  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
277  FmhaFwdAlibiKargs,
278  FmhaFwdEmptyKargs<0>>>,
279  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
280  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
281  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
282  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
283  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
284  {
289  };
290 
293  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
294  FmhaFwdCommonBiasKargs,
295  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
296  FmhaFwdAlibiKargs,
297  FmhaFwdEmptyKargs<0>>>,
298  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
299  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
300  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
301  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
302  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
303  {
307  };
308 
309  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
310 
312  {
316  };
317 
318  template <bool Cond = !kIsGroupMode>
319  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
320  MakeKargs(const void* q_ptr,
321  const void* k_ptr,
322  const void* v_ptr,
323  const void* bias_ptr,
324  void* rand_val_ptr,
325  void* lse_ptr,
326  void* o_ptr,
327  ck_tile::index_t seqlen_q,
328  ck_tile::index_t hdim_q,
329  ck_tile::index_t hdim_v,
330  ck_tile::index_t num_head_q,
331  ck_tile::index_t nhead_ratio_qk,
332  int32_t num_total_pages,
333  const void* kv_indptr,
334  const void* kv_page_indices,
335 #if 0 // we assume page_block_size=1 for now
336  const void* kv_last_page_lens,
337  ck_tile::index_t page_block_size,
338 #endif
339  float scale_s,
340  float scale_p,
341  float scale_o,
342  float logits_soft_cap,
343  ck_tile::index_t stride_q,
344  ck_tile::index_t stride_k,
345  ck_tile::index_t stride_v,
346  ck_tile::index_t stride_bias,
347  ck_tile::index_t stride_randval,
348  ck_tile::index_t stride_o,
349  ck_tile::index_t nhead_stride_q,
350  ck_tile::index_t nhead_stride_k,
351  ck_tile::index_t nhead_stride_v,
352  ck_tile::index_t nhead_stride_bias,
353  ck_tile::index_t nhead_stride_randval,
354  ck_tile::index_t nhead_stride_lse,
355  ck_tile::index_t nhead_stride_o,
356  ck_tile::index_t batch_stride_q,
357  ck_tile::index_t batch_stride_k,
358  ck_tile::index_t batch_stride_v,
359  ck_tile::index_t batch_stride_bias,
360  ck_tile::index_t batch_stride_randval,
361  ck_tile::index_t batch_stride_lse,
362  ck_tile::index_t batch_stride_o,
363  ck_tile::index_t window_size_left,
364  ck_tile::index_t window_size_right,
365  ck_tile::index_t mask_type,
366  float p_drop,
367  bool s_randval,
368  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
369  drop_seed_offset)
370  {
371  Kargs kargs{{q_ptr,
372  k_ptr,
373  v_ptr,
374  o_ptr,
375  seqlen_q,
376  -1,
377  hdim_q,
378  hdim_v,
379  num_head_q,
380  nhead_ratio_qk,
381  num_total_pages,
382  reinterpret_cast<const int32_t*>(kv_indptr),
383  reinterpret_cast<const int32_t*>(kv_page_indices),
384 #if 0 // we assume page_block_size=1 for now
385  reinterpret_cast<const int32_t*>(kv_last_page_lens),
386  page_block_size,
387 #endif
389  static_cast<float>(scale_s * ck_tile::log2e_v<>),
390 #else
391  scale_s,
392 #endif
393  stride_q,
394  stride_k,
395  stride_v,
396  stride_o,
397  nhead_stride_q,
398  nhead_stride_k,
399  nhead_stride_v,
400  nhead_stride_o}, // args for common karg
401  {}, // placeholder for bias
402  {}, // placeholder for mask
403  {}, // placeholder for lse
404  {}, // placeholder for fp8_static_quant args
405  {}, // placeholder for dropout
406  {}, // placeholder for logits_soft_cap
407  batch_stride_q,
408  batch_stride_k,
409  batch_stride_v,
410  batch_stride_o};
411 
413  {
414  kargs.bias_ptr = bias_ptr;
415  kargs.stride_bias = stride_bias;
416  kargs.nhead_stride_bias = nhead_stride_bias;
417  kargs.batch_stride_bias = batch_stride_bias;
418  }
419  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
420  {
421  kargs.alibi_slope_ptr = bias_ptr;
422  kargs.alibi_slope_stride = stride_bias;
423  }
424  if constexpr(kHasMask)
425  {
426  kargs.window_size_left = window_size_left;
427  kargs.window_size_right = window_size_right;
428  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
429  }
430  if constexpr(kStoreLSE)
431  {
432  kargs.lse_ptr = lse_ptr;
433  kargs.nhead_stride_lse = nhead_stride_lse;
434  kargs.batch_stride_lse = batch_stride_lse;
435  }
436  if constexpr(kDoFp8StaticQuant)
437  {
438  kargs.scale_p = scale_p;
439  kargs.scale_o = scale_o;
440  }
441  if constexpr(kHasDropout)
442  {
443  if(drop_seed_offset.index() == 0) // seed & offset come from host
444  {
445  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
446  kargs.init_dropout(p_drop, seed, offset);
447  }
448  else // seed & offset come from device
449  {
450  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
451  kargs.init_dropout(p_drop,
452  reinterpret_cast<const uint64_t*>(seed_ptr),
453  reinterpret_cast<const uint64_t*>(offset_ptr));
454  }
455 
456  kargs.rand_val_ptr = rand_val_ptr;
457  kargs.stride_randval = stride_randval;
458  kargs.nhead_stride_randval = nhead_stride_randval;
459  kargs.batch_stride_randval = batch_stride_randval;
460  kargs.is_store_randval = s_randval;
461  }
462  if constexpr(kHasLogitsSoftCap)
463  {
464  kargs.init_logits_soft_cap(logits_soft_cap);
465  }
466 
467  return kargs;
468  }
469 
470  template <bool Cond = kIsGroupMode>
471  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
472  MakeKargs(const void* q_ptr,
473  const void* k_ptr,
474  const void* v_ptr,
475  const void* bias_ptr,
476  void* rand_val_ptr,
477  void* lse_ptr,
478  void* o_ptr,
479  const void* seqstart_q_ptr,
480  ck_tile::index_t hdim_q,
481  ck_tile::index_t hdim_v,
482  ck_tile::index_t num_head_q,
483  ck_tile::index_t nhead_ratio_qk,
484  int32_t num_total_pages,
485  const void* kv_indptr,
486  const void* kv_page_indices,
487 #if 0 // we assume page_block_size=1 for now
488  const void* kv_last_page_lens,
489  ck_tile::index_t page_block_size,
490 #endif
491  float scale_s,
492  float scale_p,
493  float scale_o,
494  float logits_soft_cap,
495  ck_tile::index_t stride_q,
496  ck_tile::index_t stride_k,
497  ck_tile::index_t stride_v,
498  ck_tile::index_t stride_bias,
499  ck_tile::index_t stride_randval,
500  ck_tile::index_t stride_o,
501  ck_tile::index_t nhead_stride_q,
502  ck_tile::index_t nhead_stride_k,
503  ck_tile::index_t nhead_stride_v,
504  ck_tile::index_t nhead_stride_bias,
505  ck_tile::index_t nhead_stride_randval,
506  ck_tile::index_t nhead_stride_lse,
507  ck_tile::index_t nhead_stride_o,
508  ck_tile::index_t batch_stride_k,
509  ck_tile::index_t batch_stride_v,
510  ck_tile::index_t window_size_left,
511  ck_tile::index_t window_size_right,
512  ck_tile::index_t mask_type,
513  float p_drop,
514  bool s_randval,
515  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
516  drop_seed_offset)
517  {
518  Kargs kargs{{q_ptr,
519  k_ptr,
520  v_ptr,
521  o_ptr,
522  -1, // seqlen will be updated by another pointer
523  -1, //
524  hdim_q,
525  hdim_v,
526  num_head_q,
527  nhead_ratio_qk,
528  num_total_pages,
529  reinterpret_cast<const int32_t*>(kv_indptr),
530  reinterpret_cast<const int32_t*>(kv_page_indices),
531 #if 0 // we assume page_block_size=1 for now
532  reinterpret_cast<const int32_t*>(kv_last_page_lens),
533  page_block_size,
534 #endif
536  static_cast<float>(scale_s * ck_tile::log2e_v<>),
537 #else
538  scale_s,
539 #endif
540  stride_q,
541  stride_k,
542  stride_v,
543  stride_o,
544  nhead_stride_q,
545  nhead_stride_k,
546  nhead_stride_v,
547  nhead_stride_o}, // args for common karg
548  {}, // placeholder for bias
549  {}, // placeholder for mask
550  {}, // placeholder for lse
551  {}, // placeholder for fp8_static_quant args
552  {}, // placeholder for dropout
553  {}, // placeholder for logits_soft_cap
554  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
555  batch_stride_k,
556  batch_stride_v};
557 
559  {
560  kargs.bias_ptr = bias_ptr;
561  kargs.stride_bias = stride_bias;
562  kargs.nhead_stride_bias = nhead_stride_bias;
563  }
564  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
565  {
566  kargs.alibi_slope_ptr = bias_ptr;
567  kargs.alibi_slope_stride = stride_bias;
568  }
569  if constexpr(kHasMask)
570  {
571  kargs.window_size_left = window_size_left;
572  kargs.window_size_right = window_size_right;
573  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
574  }
575  if constexpr(kStoreLSE)
576  {
577  kargs.lse_ptr = lse_ptr;
578  kargs.nhead_stride_lse = nhead_stride_lse;
579  }
580  if constexpr(kDoFp8StaticQuant)
581  {
582  kargs.scale_p = scale_p;
583  kargs.scale_o = scale_o;
584  }
585  if constexpr(kHasDropout)
586  {
587  if(drop_seed_offset.index() == 0) // seed & offset come from host
588  {
589  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
590  kargs.init_dropout(p_drop, seed, offset);
591  }
592  else // seed & offset come from device
593  {
594  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
595  kargs.init_dropout(p_drop,
596  reinterpret_cast<const uint64_t*>(seed_ptr),
597  reinterpret_cast<const uint64_t*>(offset_ptr));
598  }
599 
600  kargs.rand_val_ptr = rand_val_ptr;
601  kargs.stride_randval = stride_randval;
602  kargs.nhead_stride_randval = nhead_stride_randval;
603  kargs.is_store_randval = s_randval;
604  }
605  if constexpr(kHasLogitsSoftCap)
606  {
607  kargs.init_logits_soft_cap(logits_soft_cap);
608  }
609 
610  return kargs;
611  }
612 
613  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
614  ck_tile::index_t nhead_,
615  ck_tile::index_t seqlen_q_,
616  ck_tile::index_t hdim_v_)
617  {
618  if constexpr(kIsGroupMode)
619  {
620  // TODO: this may need tuning
621  return dim3(nhead_,
622  batch_size_,
623  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
624  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
625  }
626  else
627  {
628  // TODO: this may need tuning
629  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
630  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
631  nhead_,
632  batch_size_);
633  }
634  }
635 
636  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
637  {
638  if constexpr(kIsGroupMode)
639  {
640  // const index_t num_tile_m0 = seqlen_q / kM0;
641  const index_t num_tile_n1 =
642  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
643 
644  const index_t i_block = blockIdx.z;
645  const index_t i_nhead = blockIdx.x;
646  const index_t i_batch = blockIdx.y;
647 
648  const auto f = [](index_t dividend, index_t divisor) {
649  index_t quotient = dividend / divisor;
650  index_t modulus = dividend - quotient * divisor;
651  return ck_tile::make_tuple(quotient, modulus);
652  };
653 
654  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
655  if constexpr(kHasMask)
656  {
657  // assume that num_tile_n1 is always 1
658  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
659  }
660  else
661  {
662  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
663  }
664  }
665  else
666  {
667  // const index_t num_tile_m0 = seqlen_q / kM0;
668  const index_t num_tile_n1 =
669  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
670 
671  const index_t i_block = blockIdx.x;
672  const index_t i_nhead = blockIdx.y;
673  const index_t i_batch = blockIdx.z;
674 
675  const auto f = [](index_t dividend, index_t divisor) {
676  index_t quotient = dividend / divisor;
677  index_t modulus = dividend - quotient * divisor;
678  return ck_tile::make_tuple(quotient, modulus);
679  };
680 
681  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
682 
683  if constexpr(kHasMask)
684  {
685  // assume that num_tile_n1 is always 1
686  return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
687  }
688  else
689  {
690  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
691  }
692  }
693  }
694 
695  CK_TILE_HOST static dim3 BlockSize()
696  {
697  if(is_wave32())
698  {
699  return dim3(kBlockSize / 2);
700  }
701  else
702  {
703  return dim3(kBlockSize);
704  }
705  }
706 
// Shared-memory (LDS) requirement: the main pipeline and the epilogue reuse
// the same scratch buffer, so the allocation is the larger of the two.
// NOTE(review): the function signature line (original line 707) was dropped by
// the doc extraction; presumably
// "CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()" —
// confirm against the upstream source.
708  {
709  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
710  }
711 
712  CK_TILE_DEVICE void operator()(Kargs kargs) const
713  {
714  // allocate LDS
715  __shared__ char smem_ptr[GetSmemSize()];
716 
717  // divide problem
718  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
719 
720  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
721  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
722 
723  long_index_t batch_offset_q = 0;
724  long_index_t batch_offset_bias = 0;
725  long_index_t batch_offset_randval = 0;
726  long_index_t batch_offset_lse = 0;
727  long_index_t batch_offset_o = 0;
728 
729  const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
730 #if 0 // we assume page_block_size=1 for now
731  const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
732 #endif
733  if constexpr(kIsGroupMode)
734  {
735  // get starting offset for each batch
736  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
737 
738  batch_offset_q = query_start * kargs.stride_q;
739 
740  kargs.kv_page_indices += kargs.kv_indptr[i_batch];
741 
743  {
744  batch_offset_bias = query_start * kargs.stride_bias;
745  }
746  if constexpr(kStoreLSE)
747  {
748  batch_offset_lse = query_start;
749  }
750  if constexpr(kHasDropout)
751  {
752  batch_offset_randval = query_start * kargs.stride_randval;
753  }
754  batch_offset_o = query_start * kargs.stride_o;
755 
756  // get real # queries & # keys under group mode
757  kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - query_start;
758 
759  // # of required blocks is different in each groups, terminate unnecessary blocks
760  // earlier
761  if(kargs.seqlen_q <= i_m0)
762  {
763  return;
764  }
765 
766 #if 0 // we assume page_block_size=1 for now
767  kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
768 #else
769  kargs.seqlen_k = num_page_blocks;
770 #endif
771  }
772  else
773  {
774  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
775 
776  kargs.kv_page_indices += kargs.kv_indptr[i_batch];
777 
779  {
780  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
781  }
782  if constexpr(kStoreLSE)
783  {
784  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
785  }
786  if constexpr(kHasDropout)
787  {
788  batch_offset_randval =
789  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
790  }
791  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
792 
793 #if 0 // we assume page_block_size=1 for now
794  kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
795 #else
796  kargs.seqlen_k = num_page_blocks;
797 #endif
798  }
799 
800  // for simplicity, batch stride we just modify the pointer
801  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
802  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
803  batch_offset_q;
804  const KDataType* k_ptr =
805  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
806  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
807  const VDataType* v_ptr =
808  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
809  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v;
810  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
811  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
812  batch_offset_o;
813 
814  // Q/K/V DRAM and DRAM window
815  const auto q_dram = [&]() {
816  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
817  q_ptr,
818  make_tuple(kargs.seqlen_q, kargs.hdim_q),
819  make_tuple(kargs.stride_q, 1),
821  number<1>{});
822  if constexpr(FmhaPipeline::kQLoadOnce)
823  {
824  return pad_tensor_view(
825  q_dram_naive,
828  }
829  else
830  {
831  return pad_tensor_view(
832  q_dram_naive,
835  }
836  }();
837  const auto k_dram = [&]() {
838  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
839  k_ptr,
840  make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q),
841  make_tuple(kargs.stride_k, 1),
843  number<1>{});
844 
845  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
846  return pad_tensor_view(
847  k_dram_naive,
850  }();
851  const auto v_dram = [&]() {
852  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
853  {
854  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
855  v_ptr,
856  make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
857  make_tuple(kargs.stride_v, 1),
859  number<1>{});
860 
861  const auto v_dram_transposed = transform_tensor_view(
862  v_dram_naive,
863  make_tuple(
864  make_pass_through_transform(kargs.hdim_v),
865  make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
868 
869  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
870  return pad_tensor_view(
871  v_dram_transposed,
874  }
875  else
876  {
877  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
878  v_ptr,
879  make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size),
880  make_tuple(kargs.stride_v, 1),
882  number<1>{});
883 
884  constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
885  return pad_tensor_view(
886  v_dram_naive,
889  }
890  }();
891 
892  auto q_dram_window = make_tile_window(
893  q_dram,
894  [&]() {
895  if constexpr(FmhaPipeline::kQLoadOnce)
898  else
900  }(),
901  {i_m0, 0});
902 
903  auto k_dram_window = make_tile_window(
905 
906  auto v_dram_window =
907  make_tile_window(v_dram,
909  {i_n1, 0});
912  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
913  constexpr auto bias_dram_window_lengths =
916  {
917  const BiasDataType* bias_ptr =
918  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
919  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
920  batch_offset_bias;
921 
922  const auto bias_dram = [&]() {
923  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
924  bias_ptr,
925  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
926  make_tuple(kargs.stride_bias, 1),
928  number<1>{});
929 
930  return pad_tensor_view(bias_dram_naive,
931  bias_dram_window_lengths,
933  }();
934 
935  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
936  }
937  else
938  {
939  return make_null_tile_window(bias_dram_window_lengths);
940  }
941  }();
942 
943  // lse
944  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
945  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
946  if constexpr(kStoreLSE)
947  {
948  LSEDataType* lse_ptr =
949  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
950  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
951 
952  const auto lse_dram = [&]() {
953  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
954  lse_ptr,
955  make_tuple(kargs.seqlen_q),
956  make_tuple(1),
957  number<1>{},
958  number<1>{});
959 
960  return pad_tensor_view(
961  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
962  }();
963 
964  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
965  }
966  else
967  {
968  return make_null_tile_window(lse_dram_window_lengths);
969  }
970  }();
971 
972  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
973  if constexpr(kHasDropout)
974  {
975  return BlockDropout{i_batch_,
976  i_nhead_,
977  kargs.num_head_q,
978  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
979  : *kargs.drop_seed.ptr,
980  kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
981  : *kargs.drop_offset.ptr,
982  kargs.rp_undrop,
983  kargs.p_undrop_in_uint8_t,
984  kargs.is_store_randval};
985  }
986  else
987  {
988  return NullBlockDropout{};
989  };
990  }();
991 
992  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
993  constexpr auto randval_dram_window_lengths =
995  if constexpr(kHasDropout)
996  {
997  RandValOutputDataType* rand_val_ptr =
998  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
999  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1000  batch_offset_randval;
1001 
1002  const auto randval_dram = [&]() {
1003  const auto randval_dram_naive =
1004  make_naive_tensor_view<address_space_enum::global>(
1005  rand_val_ptr,
1006  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1007  make_tuple(kargs.stride_randval, 1),
1008  number<1>{},
1009  number<1>{});
1010 
1011  return pad_tensor_view(randval_dram_naive,
1012  randval_dram_window_lengths,
1014  }();
1015 
1016  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1017  }
1018  else
1019  {
1020  return make_null_tile_window(randval_dram_window_lengths);
1021  }
1022  }();
1023 
1024  FmhaMask mask = [&]() {
1025  if constexpr(kHasMask)
1026  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1027  kargs.window_size_left,
1028  kargs.window_size_right,
1029  kargs.seqlen_q,
1030  kargs.seqlen_k,
1032  else
1033  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1034  }();
1035 
1036  // WA i_batch capture structure binding before c++20
1037  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1038  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1039  {
1040  // data loading, shared by entire wg
1041  // TODO: how to use s_read?
1042  SaccDataType slope =
1043  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1044  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1045 #if CK_TILE_FMHA_FWD_FAST_EXP2
1046  slope *= ck_tile::log2e_v<>;
1047 #endif
1048  if constexpr(kHasMask)
1049  {
1050  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1051  kargs.window_size_left,
1052  kargs.window_size_right,
1053  kargs.seqlen_q,
1054  kargs.seqlen_k,
1055  kargs.mask_type);
1056  }
1057  else
1058  {
1060  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1061  }
1062  }
1063  else
1064  {
1066  }
1067  }();
1068 
1069  AttentionVariant variant;
1070  const auto variant_params = [&] {
1071  if constexpr(kHasLogitsSoftCap)
1072  {
1074  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1075  }
1076  else
1077  {
1078  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1079  }
1080  }();
1081 
1082  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1083 
1084  auto o_acc_tile = [&]() {
1085  if constexpr(kDoFp8StaticQuant)
1086  {
1087  return FmhaPipeline{}(
1088  q_dram_window,
1089  identity{}, // q_element_func
1090  k_dram_window,
1091  identity{}, // k_element_func
1092  v_dram_window,
1093  identity{}, // v_element_func
1094  bias_dram_window,
1095  identity{}, // bias_element_func
1096  randval_dram_window,
1097  lse_dram_window,
1098  identity{}, // lse_element_func
1099  identity{}, // s_acc_element_func
1100  scales{kargs.scale_p}, // p_compute_element_func
1101  composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
1102  mask,
1103  position_encoding,
1104  kargs.scale_s,
1105  variant,
1106  variant_params,
1107  block_indices,
1108  smem_ptr,
1109  kargs.kv_page_indices,
1110  kargs.stride_k,
1111  kargs.stride_v,
1112  dropout);
1113  }
1114  else
1115  {
1116  return FmhaPipeline{}(q_dram_window,
1117  k_dram_window,
1118  v_dram_window,
1119  bias_dram_window,
1120  randval_dram_window,
1121  lse_dram_window,
1122  mask,
1123  position_encoding,
1124  kargs.scale_s,
1125  variant,
1126  variant_params,
1127  block_indices,
1128  smem_ptr,
1129  kargs.kv_page_indices,
1130  kargs.stride_k,
1131  kargs.stride_v,
1132  dropout);
1133  }
1134  }();
1135 
1136  // O DRAM and O DRAM window
1137  auto o_dram = [&]() {
1138  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1139  o_ptr,
1140  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1141  make_tuple(kargs.stride_o, 1),
1143  number<1>{});
1144 
1145  return pad_tensor_view(
1146  o_dram_naive,
1149  }();
1150 
1151  auto o_dram_window =
1152  make_tile_window(o_dram,
1154  {i_m0, i_n1});
1155 
1156  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1157  }
1158 };
1159 
1160 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition: config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
#define _TS_
#define _SS_
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
_Float16 fp16_t
Definition: half.hpp:110
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:511
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create macro to replace 'host device' and nothing more.
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
unsigned char uint8_t
Definition: stdint.h:124
unsigned __int64 uint64_t
Definition: stdint.h:136
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_dropout.hpp:53
const float rp_undrop
Definition: block_dropout.hpp:377
Definition: block_position_encoding.hpp:137
Definition: fmha_batch_prefill_kernel.hpp:312
ck_tile::index_t kv_head_idx
Definition: fmha_batch_prefill_kernel.hpp:315
ck_tile::index_t qo_head_idx
Definition: fmha_batch_prefill_kernel.hpp:314
ck_tile::index_t batch_idx
Definition: fmha_batch_prefill_kernel.hpp:313
Definition: fmha_batch_prefill_kernel.hpp:192
ck_tile::index_t alibi_slope_stride
Definition: fmha_batch_prefill_kernel.hpp:195
const void * alibi_slope_ptr
Definition: fmha_batch_prefill_kernel.hpp:194
ck_tile::index_t batch_stride_bias
Definition: fmha_batch_prefill_kernel.hpp:188
ck_tile::index_t batch_stride_randval
Definition: fmha_batch_prefill_kernel.hpp:269
ck_tile::index_t batch_stride_o
Definition: fmha_batch_prefill_kernel.hpp:288
ck_tile::index_t batch_stride_v
Definition: fmha_batch_prefill_kernel.hpp:287
ck_tile::index_t batch_stride_q
Definition: fmha_batch_prefill_kernel.hpp:285
ck_tile::index_t batch_stride_k
Definition: fmha_batch_prefill_kernel.hpp:286
ck_tile::index_t nhead_stride_bias
Definition: fmha_batch_prefill_kernel.hpp:183
ck_tile::index_t stride_bias
Definition: fmha_batch_prefill_kernel.hpp:182
const void * bias_ptr
Definition: fmha_batch_prefill_kernel.hpp:181
ck_tile::index_t stride_randval
Definition: fmha_batch_prefill_kernel.hpp:263
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition: fmha_batch_prefill_kernel.hpp:246
ck_tile::index_t nhead_stride_randval
Definition: fmha_batch_prefill_kernel.hpp:264
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition: fmha_batch_prefill_kernel.hpp:234
void * rand_val_ptr
Definition: fmha_batch_prefill_kernel.hpp:261
float rp_undrop
Definition: fmha_batch_prefill_kernel.hpp:258
bool is_store_randval
Definition: fmha_batch_prefill_kernel.hpp:260
uint8_t p_undrop_in_uint8_t
Definition: fmha_batch_prefill_kernel.hpp:259
Definition: fmha_batch_prefill_kernel.hpp:118
ck_tile::index_t stride_q
Definition: fmha_batch_prefill_kernel.hpp:146
ck_tile::index_t stride_v
Definition: fmha_batch_prefill_kernel.hpp:148
int32_t num_total_pages
Definition: fmha_batch_prefill_kernel.hpp:134
float scale_s
Definition: fmha_batch_prefill_kernel.hpp:144
ck_tile::index_t seqlen_q
Definition: fmha_batch_prefill_kernel.hpp:124
ck_tile::index_t stride_k
Definition: fmha_batch_prefill_kernel.hpp:147
ck_tile::index_t nhead_stride_o
Definition: fmha_batch_prefill_kernel.hpp:154
ck_tile::index_t nhead_stride_k
Definition: fmha_batch_prefill_kernel.hpp:152
ck_tile::index_t nhead_ratio_qk
Definition: fmha_batch_prefill_kernel.hpp:132
ck_tile::index_t nhead_stride_v
Definition: fmha_batch_prefill_kernel.hpp:153
ck_tile::index_t nhead_stride_q
Definition: fmha_batch_prefill_kernel.hpp:151
const int32_t * kv_page_indices
Definition: fmha_batch_prefill_kernel.hpp:136
const void * v_ptr
Definition: fmha_batch_prefill_kernel.hpp:121
const int32_t * kv_indptr
Definition: fmha_batch_prefill_kernel.hpp:135
void * o_ptr
Definition: fmha_batch_prefill_kernel.hpp:122
ck_tile::index_t seqlen_k
Definition: fmha_batch_prefill_kernel.hpp:125
ck_tile::index_t stride_o
Definition: fmha_batch_prefill_kernel.hpp:149
ck_tile::index_t hdim_v
Definition: fmha_batch_prefill_kernel.hpp:127
ck_tile::index_t num_head_q
Definition: fmha_batch_prefill_kernel.hpp:129
static constexpr ck_tile::index_t page_block_size
Definition: fmha_batch_prefill_kernel.hpp:141
const void * k_ptr
Definition: fmha_batch_prefill_kernel.hpp:120
ck_tile::index_t hdim_q
Definition: fmha_batch_prefill_kernel.hpp:126
const void * q_ptr
Definition: fmha_batch_prefill_kernel.hpp:119
ck_tile::index_t batch_stride_lse
Definition: fmha_batch_prefill_kernel.hpp:215
ck_tile::index_t nhead_stride_lse
Definition: fmha_batch_prefill_kernel.hpp:214
void * lse_ptr
Definition: fmha_batch_prefill_kernel.hpp:213
bool is_drop_seed_offset_from_host
Definition: fmha_batch_prefill_kernel.hpp:229
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_batch_prefill_kernel.hpp:227
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_batch_prefill_kernel.hpp:228
Definition: fmha_batch_prefill_kernel.hpp:111
float scale_p
Definition: fmha_batch_prefill_kernel.hpp:207
float scale_o
Definition: fmha_batch_prefill_kernel.hpp:208
ck_tile::index_t batch_stride_v
Definition: fmha_batch_prefill_kernel.hpp:306
ck_tile::index_t batch_stride_k
Definition: fmha_batch_prefill_kernel.hpp:305
const int32_t * seqstart_q_ptr
Definition: fmha_batch_prefill_kernel.hpp:304
float logits_soft_cap_rcp
Definition: fmha_batch_prefill_kernel.hpp:176
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_batch_prefill_kernel.hpp:161
float logits_soft_cap
Definition: fmha_batch_prefill_kernel.hpp:175
Definition: fmha_batch_prefill_kernel.hpp:199
ck_tile::index_t window_size_right
Definition: fmha_batch_prefill_kernel.hpp:201
ck_tile::index_t window_size_left
Definition: fmha_batch_prefill_kernel.hpp:201
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_batch_prefill_kernel.hpp:202
Definition: fmha_batch_prefill_kernel.hpp:64
Definition: fmha_batch_prefill_kernel.hpp:26
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_batch_prefill_kernel.hpp:636
static constexpr bool kIsGroupMode
Definition: fmha_batch_prefill_kernel.hpp:47
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_batch_prefill_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_batch_prefill_kernel.hpp:37
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_batch_prefill_kernel.hpp:27
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_batch_prefill_kernel.hpp:36
static constexpr bool kPadSeqLenQ
Definition: fmha_batch_prefill_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_batch_prefill_kernel.hpp:40
static constexpr bool kDoFp8StaticQuant
Definition: fmha_batch_prefill_kernel.hpp:56
static constexpr bool kPadHeadDimV
Definition: fmha_batch_prefill_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_batch_prefill_kernel.hpp:41
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_batch_prefill_kernel.hpp:35
static constexpr bool kHasMask
Definition: fmha_batch_prefill_kernel.hpp:59
static CK_TILE_HOST std::string GetName()
Definition: fmha_batch_prefill_kernel.hpp:72
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_batch_prefill_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_batch_prefill_kernel.hpp:58
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_batch_prefill_kernel.hpp:695
static constexpr bool kPadSeqLenK
Definition: fmha_batch_prefill_kernel.hpp:49
static constexpr bool kHasLogitsSoftCap
Definition: fmha_batch_prefill_kernel.hpp:52
static constexpr bool kHasDropout
Definition: fmha_batch_prefill_kernel.hpp:55
static constexpr bool kStoreLSE
Definition: fmha_batch_prefill_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_batch_prefill_kernel.hpp:43
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_batch_prefill_kernel.hpp:707
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_batch_prefill_kernel.hpp:29
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, const void *kv_indptr, const void *kv_page_indices, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_batch_prefill_kernel.hpp:472
static constexpr auto BiasEnum
Definition: fmha_batch_prefill_kernel.hpp:53
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_batch_prefill_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_batch_prefill_kernel.hpp:42
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_batch_prefill_kernel.hpp:57
static constexpr bool kUseAsyncCopy
Definition: fmha_batch_prefill_kernel.hpp:61
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, const void *kv_indptr, const void *kv_page_indices, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_batch_prefill_kernel.hpp:320
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_batch_prefill_kernel.hpp:33
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_)
Definition: fmha_batch_prefill_kernel.hpp:613
static constexpr bool kPadHeadDimQ
Definition: fmha_batch_prefill_kernel.hpp:50
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_batch_prefill_kernel.hpp:712
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_batch_prefill_kernel.hpp:309
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_batch_prefill_kernel.hpp:28
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: block_dropout.hpp:39
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: functional.hpp:86
Definition: coordinate_transform.hpp:1392
Definition: unary_element_function.hpp:56
Definition: math.hpp:28
Definition: sequence.hpp:49