Composable Kernel: include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 #define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
17 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
18 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
19 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
20 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
21 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
22 
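// For intuition, a minimal scalar reference of the computation sketched above
// (hypothetical host-side code, assuming <vector>, <cmath> and <limits>; the kernel
// below computes the same result tile-by-tile with an online softmax):
//
//   void naive_fmha_fwd(const float* Q, const float* K, const float* V, float* O,
//                       int seqlen_q, int seqlen_k, int hdim_q, int hdim_v, float scale)
//   {
//       std::vector<float> s(seqlen_k);
//       for(int i = 0; i < seqlen_q; ++i)
//       {
//           float m = -std::numeric_limits<float>::infinity(), l = 0.f;
//           for(int j = 0; j < seqlen_k; ++j)
//           {
//               s[j] = 0.f;
//               for(int d = 0; d < hdim_q; ++d)
//                   s[j] += Q[i * hdim_q + d] * K[j * hdim_q + d]; // S = Q @ K^T
//               s[j] *= scale;                                     // S' = S * Scale
//               m = std::max(m, s[j]);
//           }
//           for(int j = 0; j < seqlen_k; ++j)
//           {
//               s[j] = std::exp(s[j] - m); // P = Softmax(S''), numerically stabilized
//               l += s[j];
//           }
//           for(int d = 0; d < hdim_v; ++d)
//           {
//               float acc = 0.f;
//               for(int j = 0; j < seqlen_k; ++j)
//                   acc += s[j] * V[j * hdim_v + d]; // O = P @ V
//               O[i * hdim_v + d] = acc / l;
//           }
//       }
//   }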
23 namespace ck_tile {
24 
25 template <typename FmhaPipeline_, typename EpiloguePipeline_>
26  struct FmhaFwdKernel
27  {
28  using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
29  using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
30  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
31 
32  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
33  static_assert(kBlockPerCu > 0);
34  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
35 
36  using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
37  using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
38  using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
39  using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
40  using RandValOutputDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
41  using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
42  using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
43  using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
44  using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
45 
47 
48  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
49  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
50  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
51  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
52  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
53  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
54  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
55  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
56  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
57  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
58  static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
59 
60  using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
61  using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
62  static constexpr bool kHasMask = FmhaMask::IsMasking;
63 
64  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
65 
66  static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
67 #if defined(__gfx950__)
68  static constexpr bool kIsAvailable = true;
69 #else
70  static constexpr bool kIsAvailable = !kUseTrLoad;
71 #endif
72  static constexpr std::string_view kPipelineName = FmhaPipeline::name;
73 
74  // clang-format off
75  template <typename T1, typename T2 = T1> struct t2s;
76  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
77  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
78  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
79  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
80  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
81  template <> struct t2s<ck_tile::fp8_t, ck_tile::bf16_t> { static constexpr const char * name = "fp8bf16"; };
82  template <> struct t2s<ck_tile::fp8_t, ck_tile::fp32_t> { static constexpr const char * name = "fp8fp32"; };
83  // clang-format on
84 
85  CK_TILE_HOST static std::string GetName()
86  {
87  // sync with generate.py
88  // clang-format off
89  using bfs = typename FmhaPipeline::BlockFmhaShape;
90  using g0br = typename bfs::Gemm0BlockWarps;
91  using g1br = typename bfs::Gemm1BlockWarps;
92  using g0wt = typename bfs::Gemm0WarpTile;
93  using g1wt = typename bfs::Gemm1WarpTile;
94  #define _SS_ std::string
95  #define _TS_ std::to_string
96  auto pn = [&] () {
97  std::string n;
98  if (kPadSeqLenQ) n += "s";
99  if (kPadSeqLenK) n += "sk";
100  if (kPadHeadDimQ) n += "d";
101  if (kPadHeadDimV) n += "dv";
102  return n.empty() ? n : std::string("p") + n; }();
103  return
104  _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType, ODataType>::name) +
105  "_" + (kIsGroupMode ? "group" : "batch") + "_"
106  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
107  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
108  "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
109  "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
110  "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
111  "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
112  (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
113  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
114  (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
115  (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload");
116  #undef _SS_
117  #undef _TS_
118  // clang-format on
119  }
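 
  // As an illustration (hypothetical values, derived from the format string above),
  // a 128-head-dim fp16 batch-mode instantiation with every optional feature off
  // would yield a name shaped like:
  //
  //   fmha_fwd_d128_fp16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_vr_npad_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant_ntrload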
120 
121  template <ck_tile::index_t I> // to avoid the duplicated-base-class problem,
122  // introduce a template arg
123  struct FmhaFwdEmptyKargs
124  {
125  };
126 
127  // kargs use aggregate initialization, so no constructor is provided
128  // use inheritance to minimize karg size
129  // users need to call the MakeKargs() function to create kargs.
130  struct FmhaFwdCommonKargs
131  {
132  const void* q_ptr;
133  const void* k_ptr;
134  const void* v_ptr;
135  void* o_ptr;
136 
137  ck_tile::index_t seqlen_q;
138  ck_tile::index_t seqlen_k;
139  ck_tile::index_t hdim_q;
140  ck_tile::index_t hdim_v;
141 
142  ck_tile::index_t num_head_q;
143  // for MQA/GQA, nhead can differ between Q and K/V. This parameter is
144  // nhead_q / nhead_k; a value larger than 1 indicates an MQA/GQA case
145  ck_tile::index_t nhead_ratio_qk;
146  float scale_s;
147 
148  ck_tile::index_t stride_q;
149  ck_tile::index_t stride_k;
150  ck_tile::index_t stride_v;
151  ck_tile::index_t stride_o;
152 
153  ck_tile::index_t nhead_stride_q;
154  ck_tile::index_t nhead_stride_k;
155  ck_tile::index_t nhead_stride_v;
156  ck_tile::index_t nhead_stride_o;
157  };
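  // Illustrative MQA/GQA example: nhead_q = 32 query heads sharing nhead_k = 8
  // K/V heads gives nhead_ratio_qk = 4; the kernel maps query head h to K/V head
  // h / nhead_ratio_qk (see the i_nhead / kargs.nhead_ratio_qk indexing below).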
158 
159  struct FmhaFwdLogitsSoftCapKargs
160  {
163  void init_logits_soft_cap(float logits_soft_cap_)
164  {
165  if(0 < logits_soft_cap_)
166  {
167  logits_soft_cap = logits_soft_cap_;
168  logits_soft_cap_rcp = 1.f / logits_soft_cap_;
169  }
170  else
171  {
172  logits_soft_cap = 0.f;
173  logits_soft_cap_rcp = 0.f;
174  }
175  }
176 
177  float logits_soft_cap = 0.f;
178  float logits_soft_cap_rcp = 0.f;
179  };
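  // Sketch of how such a cap is typically applied to an attention logit (the
  // actual element-wise functor lives in the pipeline; shown here only to
  // motivate caching the reciprocal, assuming <cmath>):
  //
  //   float apply_soft_cap(float s, float cap, float cap_rcp)
  //   {
  //       return cap * std::tanh(s * cap_rcp); // cap_rcp caches 1 / cap
  //   }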
180 
181  struct FmhaFwdCommonBiasKargs
182  {
183  const void* bias_ptr = nullptr;
184  ck_tile::index_t stride_bias = 0;
185  ck_tile::index_t nhead_stride_bias = 0;
186  };
187 
188  struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs
189  {
190  ck_tile::index_t batch_stride_bias = 0;
191  };
192 
193  struct FmhaFwdAlibiKargs
194  {
195  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
196  const void* alibi_slope_ptr;
197  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
198  };
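  // Roughly (illustrative, not the exact in-kernel formula): with one slope per
  // head, ALiBi subtracts slope times the query/key distance from each logit,
  // with the distance measured from the bottom-right-aligned diagonal by default:
  //
  //   s(i, j) -= slope * abs((seqlen_k - seqlen_q + i) - j);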
199 
200  struct FmhaFwdMaskKargs
201  {
202  // ck_tile::index_t window_size_left, window_size_right;
203  ck_tile::index_t window_size_left, window_size_right;
204  ck_tile::GenericAttentionMaskEnum mask_type;
205  };
206 
207  struct FmhaFwdFp8StaticQuantKargs
208  {
209  float scale_p;
210  float scale_o;
211  };
212 
213  struct FmhaFwdCommonLSEKargs
214  {
215  void* lse_ptr = nullptr;
216  ck_tile::index_t nhead_stride_lse = 0;
217  ck_tile::index_t batch_stride_lse = 0;
218  };
219 
220  struct FmhaFwdDropoutSeedOffset
221  {
222  template <typename T>
223  union ValueOrPointer
224  {
225  T val;
226  const T* ptr;
227  };
228 
229  ValueOrPointer<uint64_t> drop_seed;
230  ValueOrPointer<uint64_t> drop_offset;
231  bool is_drop_seed_offset_from_host;
232  };
233 
234  struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
235  {
236  void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
237  {
238  float p_undrop = 1.0 - p_drop;
239  p_undrop_in_uint8_t =
240  uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
241  rp_undrop = 1.0 / p_undrop;
242 
243  this->drop_seed.val = seed;
244  this->drop_offset.val = offset;
245  this->is_drop_seed_offset_from_host = true;
246  }
247 
248  void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
249  {
250  float p_undrop = 1.0 - p_drop;
251  p_undrop_in_uint8_t =
252  uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
253  rp_undrop = 1.0 / p_undrop;
254 
255  this->drop_seed.ptr = seed_ptr;
256  this->drop_offset.ptr = offset_ptr;
257  this->is_drop_seed_offset_from_host = false;
258  }
259 
260  float rp_undrop = 1;
261  uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
262  bool is_store_randval = false;
263  void* rand_val_ptr = nullptr;
264 
265  ck_tile::index_t stride_randval = 0;
266  ck_tile::index_t nhead_stride_randval = 0;
267  };
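  // Illustrative use of these fields (hypothetical helper): an element survives
  // dropout when its random byte is at or below p_undrop_in_uint8_t, and survivors
  // are rescaled by rp_undrop = 1 / (1 - p_drop) to keep the expectation unchanged:
  //
  //   float apply_dropout(float p, uint8_t rnd, uint8_t keep_threshold, float rp)
  //   {
  //       return rnd <= keep_threshold ? p * rp : 0.f;
  //   }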
268 
269  struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
270  {
271  ck_tile::index_t batch_stride_randval = 0;
272  };
273 
274  struct FmhaFwdSkipMinSeqlenQKargs
275  {
276  ck_tile::index_t min_seqlen_q;
277  };
278 
279  struct FmhaFwdBatchModeKargs
280  : FmhaFwdCommonKargs,
281  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
282  FmhaFwdBatchModeBiasKargs,
283  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
284  FmhaFwdAlibiKargs,
285  FmhaFwdEmptyKargs<0>>>,
286  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
287  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
288  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
289  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
290  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
291  {
292  ck_tile::index_t batch_stride_q;
293  ck_tile::index_t batch_stride_k;
294  ck_tile::index_t batch_stride_v;
295  ck_tile::index_t batch_stride_o;
296 
297  // Optional cumulative sequence length pointers for batch mode
298  // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
299  const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
300  const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD
301  };
302 
303  struct FmhaFwdGroupModeKargs
304  : FmhaFwdCommonKargs,
305  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
306  FmhaFwdCommonBiasKargs,
307  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
308  FmhaFwdAlibiKargs,
309  FmhaFwdEmptyKargs<0>>>,
310  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
311  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
312  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
313  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
314  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
315  std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
316  {
317  const int32_t* seqstart_q_ptr;
318  const int32_t* seqstart_k_ptr;
319  const int32_t* seqlen_k_ptr;
320 
321  // Optional cumulative padded sequence starts (including PAD tokens)
322  // Used solely to compute memory offsets when sequences are physically padded.
323  const int32_t* seqstart_padded_q_ptr = nullptr;
324  const int32_t* seqstart_padded_k_ptr = nullptr;
325  };
326 
327  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
328 
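  // Because disabled features contribute empty bases, they add no bytes to Kargs;
  // e.g. (illustrative) a batch-mode instantiation with every optional feature off
  // pays only for FmhaFwdCommonKargs plus the four batch strides:
  //
  //   static_assert(std::is_empty_v<FmhaFwdEmptyKargs<0>>);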
329  struct BlockIndices
330  {
331  ck_tile::index_t batch_idx;
332  ck_tile::index_t qo_head_idx;
333  ck_tile::index_t kv_head_idx;
334  };
335 
336  template <bool Cond = !kIsGroupMode>
337  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
338  MakeKargsImpl(const void* q_ptr,
339  const void* k_ptr,
340  const void* v_ptr,
341  const void* bias_ptr,
342  void* rand_val_ptr,
343  void* lse_ptr,
344  void* o_ptr,
345  ck_tile::index_t seqlen_q,
346  ck_tile::index_t seqlen_k,
347  ck_tile::index_t hdim_q,
348  ck_tile::index_t hdim_v,
349  ck_tile::index_t num_head_q,
350  ck_tile::index_t nhead_ratio_qk,
351  float scale_s,
352  float scale_p,
353  float scale_o,
354  float logits_soft_cap,
355  ck_tile::index_t stride_q,
356  ck_tile::index_t stride_k,
357  ck_tile::index_t stride_v,
358  ck_tile::index_t stride_bias,
359  ck_tile::index_t stride_randval,
360  ck_tile::index_t stride_o,
361  ck_tile::index_t nhead_stride_q,
362  ck_tile::index_t nhead_stride_k,
363  ck_tile::index_t nhead_stride_v,
364  ck_tile::index_t nhead_stride_bias,
365  ck_tile::index_t nhead_stride_randval,
366  ck_tile::index_t nhead_stride_lse,
367  ck_tile::index_t nhead_stride_o,
368  ck_tile::index_t batch_stride_q,
369  ck_tile::index_t batch_stride_k,
370  ck_tile::index_t batch_stride_v,
371  ck_tile::index_t batch_stride_bias,
372  ck_tile::index_t batch_stride_randval,
373  ck_tile::index_t batch_stride_lse,
374  ck_tile::index_t batch_stride_o,
375  ck_tile::index_t window_size_left,
376  ck_tile::index_t window_size_right,
377  ck_tile::index_t mask_type,
378  float p_drop,
379  bool s_randval,
380  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
381  drop_seed_offset,
382  const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
383  const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
384  {
385  Kargs kargs{{q_ptr,
386  k_ptr,
387  v_ptr,
388  o_ptr,
389  seqlen_q,
390  seqlen_k,
391  hdim_q,
392  hdim_v,
393  num_head_q,
394  nhead_ratio_qk,
395 #if CK_TILE_FMHA_FWD_FAST_EXP2
396  static_cast<float>(scale_s * ck_tile::log2e_v<>),
397 #else
398  scale_s,
399 #endif
400  stride_q,
401  stride_k,
402  stride_v,
403  stride_o,
404  nhead_stride_q,
405  nhead_stride_k,
406  nhead_stride_v,
407  nhead_stride_o}, // args for common karg
408  {}, // placeholder for bias
409  {}, // placeholder for mask
410  {}, // placeholder for lse
411  {}, // placeholder for fp8_static_quant args
412  {}, // placeholder for dropout
413  {}, // placeholder for logits_soft_cap
414  batch_stride_q,
415  batch_stride_k,
416  batch_stride_v,
417  batch_stride_o};
418 
419  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
420  {
421  kargs.bias_ptr = bias_ptr;
422  kargs.stride_bias = stride_bias;
423  kargs.nhead_stride_bias = nhead_stride_bias;
424  kargs.batch_stride_bias = batch_stride_bias;
425  }
426  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
427  {
428  kargs.alibi_slope_ptr = bias_ptr;
429  kargs.alibi_slope_stride = stride_bias;
430  }
431  if constexpr(kHasMask)
432  {
433  kargs.window_size_left = window_size_left;
434  kargs.window_size_right = window_size_right;
435  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
436  }
437  if constexpr(kStoreLSE)
438  {
439  kargs.lse_ptr = lse_ptr;
440  kargs.nhead_stride_lse = nhead_stride_lse;
441  kargs.batch_stride_lse = batch_stride_lse;
442  }
443  if constexpr(kDoFp8StaticQuant)
444  {
445  kargs.scale_p = scale_p;
446  kargs.scale_o = scale_o;
447  }
448  if constexpr(kHasDropout)
449  {
450  if(drop_seed_offset.index() == 0) // seed & offset come from host
451  {
452  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
453  kargs.init_dropout(p_drop, seed, offset);
454  }
455  else // seed & offset come from device
456  {
457  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
458  kargs.init_dropout(p_drop,
459  reinterpret_cast<const uint64_t*>(seed_ptr),
460  reinterpret_cast<const uint64_t*>(offset_ptr));
461  }
462 
463  kargs.rand_val_ptr = rand_val_ptr;
464  kargs.stride_randval = stride_randval;
465  kargs.nhead_stride_randval = nhead_stride_randval;
466  kargs.batch_stride_randval = batch_stride_randval;
467  kargs.is_store_randval = s_randval;
468  }
469  if constexpr(kHasLogitsSoftCap)
470  {
471  kargs.init_logits_soft_cap(logits_soft_cap);
472  }
473 
474  kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
475  kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
476  return kargs;
477  }
478 
479  // std::variant<> can't be constructed from a braced initializer list; overload kept for backward compatibility
480  template <bool Cond = !kIsGroupMode>
481  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
482  MakeKargs(const void* q_ptr,
483  const void* k_ptr,
484  const void* v_ptr,
485  const void* bias_ptr,
486  void* rand_val_ptr,
487  void* lse_ptr,
488  void* o_ptr,
489  ck_tile::index_t seqlen_q,
490  ck_tile::index_t seqlen_k,
491  ck_tile::index_t hdim_q,
492  ck_tile::index_t hdim_v,
493  ck_tile::index_t num_head_q,
494  ck_tile::index_t nhead_ratio_qk,
495  float scale_s,
496  float scale_p,
497  float scale_o,
498  float logits_soft_cap,
499  ck_tile::index_t stride_q,
500  ck_tile::index_t stride_k,
501  ck_tile::index_t stride_v,
502  ck_tile::index_t stride_bias,
503  ck_tile::index_t stride_randval,
504  ck_tile::index_t stride_o,
505  ck_tile::index_t nhead_stride_q,
506  ck_tile::index_t nhead_stride_k,
507  ck_tile::index_t nhead_stride_v,
508  ck_tile::index_t nhead_stride_bias,
509  ck_tile::index_t nhead_stride_randval,
510  ck_tile::index_t nhead_stride_lse,
511  ck_tile::index_t nhead_stride_o,
512  ck_tile::index_t batch_stride_q,
513  ck_tile::index_t batch_stride_k,
514  ck_tile::index_t batch_stride_v,
515  ck_tile::index_t batch_stride_bias,
516  ck_tile::index_t batch_stride_randval,
517  ck_tile::index_t batch_stride_lse,
518  ck_tile::index_t batch_stride_o,
519  ck_tile::index_t window_size_left,
520  ck_tile::index_t window_size_right,
521  ck_tile::index_t mask_type,
522  float p_drop,
523  bool s_randval,
524  const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
525  const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
526  const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
527  {
528  return MakeKargsImpl(
529  q_ptr,
530  k_ptr,
531  v_ptr,
532  bias_ptr,
533  rand_val_ptr,
534  lse_ptr,
535  o_ptr,
536  seqlen_q,
537  seqlen_k,
538  hdim_q,
539  hdim_v,
540  num_head_q,
541  nhead_ratio_qk,
542  scale_s,
543  scale_p,
544  scale_o,
545  logits_soft_cap,
546  stride_q,
547  stride_k,
548  stride_v,
549  stride_bias,
550  stride_randval,
551  stride_o,
552  nhead_stride_q,
553  nhead_stride_k,
554  nhead_stride_v,
555  nhead_stride_bias,
556  nhead_stride_randval,
557  nhead_stride_lse,
558  nhead_stride_o,
559  batch_stride_q,
560  batch_stride_k,
561  batch_stride_v,
562  batch_stride_bias,
563  batch_stride_randval,
564  batch_stride_lse,
565  batch_stride_o,
566  window_size_left,
567  window_size_right,
568  mask_type,
569  p_drop,
570  s_randval,
571  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
572  cu_seqlen_q_ptr,
573  cu_seqlen_kv_ptr);
574  }
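  // A hypothetical host-side call for a [batch, seqlen, nhead, hdim] fp16 layout
  // with no bias/LSE/dropout (all values below are illustrative, not prescriptive;
  // Kernel stands for a concrete FmhaFwdKernel instantiation):
  //
  //   auto kargs = Kernel::MakeKargs(
  //       q, k, v, /*bias*/ nullptr, /*rand_val*/ nullptr, /*lse*/ nullptr, o,
  //       seqlen_q, seqlen_k, hdim, hdim, nhead_q, nhead_q / nhead_k,
  //       scale_s, /*scale_p*/ 1.f, /*scale_o*/ 1.f, /*logits_soft_cap*/ 0.f,
  //       /*stride_q*/ nhead_q * hdim, /*stride_k*/ nhead_k * hdim,
  //       /*stride_v*/ nhead_k * hdim, /*stride_bias*/ 0, /*stride_randval*/ 0,
  //       /*stride_o*/ nhead_q * hdim,
  //       /*nhead_stride_q*/ hdim, /*nhead_stride_k*/ hdim, /*nhead_stride_v*/ hdim,
  //       /*nhead_stride_bias*/ 0, /*nhead_stride_randval*/ 0,
  //       /*nhead_stride_lse*/ seqlen_q, /*nhead_stride_o*/ hdim,
  //       /*batch_stride_q*/ seqlen_q * nhead_q * hdim,
  //       /*batch_stride_k*/ seqlen_k * nhead_k * hdim,
  //       /*batch_stride_v*/ seqlen_k * nhead_k * hdim,
  //       /*batch_stride_bias*/ 0, /*batch_stride_randval*/ 0,
  //       /*batch_stride_lse*/ nhead_q * seqlen_q,
  //       /*batch_stride_o*/ seqlen_q * nhead_q * hdim,
  //       /*window_size_left*/ -1, /*window_size_right*/ -1, /*mask_type*/ 0,
  //       /*p_drop*/ 0.f, /*s_randval*/ false,
  //       std::make_tuple(uint64_t{0}, uint64_t{0}));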
575 
576  // std::variant<> can't be constructed from a braced initializer list; overload kept for backward compatibility
577  template <bool Cond = !kIsGroupMode>
578  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
579  MakeKargs(const void* q_ptr,
580  const void* k_ptr,
581  const void* v_ptr,
582  const void* bias_ptr,
583  void* rand_val_ptr,
584  void* lse_ptr,
585  void* o_ptr,
586  ck_tile::index_t seqlen_q,
587  ck_tile::index_t seqlen_k,
588  ck_tile::index_t hdim_q,
589  ck_tile::index_t hdim_v,
590  ck_tile::index_t num_head_q,
591  ck_tile::index_t nhead_ratio_qk,
592  float scale_s,
593  float scale_p,
594  float scale_o,
595  float logits_soft_cap,
596  ck_tile::index_t stride_q,
597  ck_tile::index_t stride_k,
598  ck_tile::index_t stride_v,
599  ck_tile::index_t stride_bias,
600  ck_tile::index_t stride_randval,
601  ck_tile::index_t stride_o,
602  ck_tile::index_t nhead_stride_q,
603  ck_tile::index_t nhead_stride_k,
604  ck_tile::index_t nhead_stride_v,
605  ck_tile::index_t nhead_stride_bias,
606  ck_tile::index_t nhead_stride_randval,
607  ck_tile::index_t nhead_stride_lse,
608  ck_tile::index_t nhead_stride_o,
609  ck_tile::index_t batch_stride_q,
610  ck_tile::index_t batch_stride_k,
611  ck_tile::index_t batch_stride_v,
612  ck_tile::index_t batch_stride_bias,
613  ck_tile::index_t batch_stride_randval,
614  ck_tile::index_t batch_stride_lse,
615  ck_tile::index_t batch_stride_o,
616  ck_tile::index_t window_size_left,
617  ck_tile::index_t window_size_right,
618  ck_tile::index_t mask_type,
619  float p_drop,
620  bool s_randval,
621  const std::tuple<const void*, const void*>& drop_seed_offset,
622  const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
623  const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
624  {
625  return MakeKargsImpl(
626  q_ptr,
627  k_ptr,
628  v_ptr,
629  bias_ptr,
630  rand_val_ptr,
631  lse_ptr,
632  o_ptr,
633  seqlen_q,
634  seqlen_k,
635  hdim_q,
636  hdim_v,
637  num_head_q,
638  nhead_ratio_qk,
639  scale_s,
640  scale_p,
641  scale_o,
642  logits_soft_cap,
643  stride_q,
644  stride_k,
645  stride_v,
646  stride_bias,
647  stride_randval,
648  stride_o,
649  nhead_stride_q,
650  nhead_stride_k,
651  nhead_stride_v,
652  nhead_stride_bias,
653  nhead_stride_randval,
654  nhead_stride_lse,
655  nhead_stride_o,
656  batch_stride_q,
657  batch_stride_k,
658  batch_stride_v,
659  batch_stride_bias,
660  batch_stride_randval,
661  batch_stride_lse,
662  batch_stride_o,
663  window_size_left,
664  window_size_right,
665  mask_type,
666  p_drop,
667  s_randval,
668  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
669  cu_seqlen_q_ptr,
670  cu_seqlen_kv_ptr);
671  }
672 
673  template <bool Cond = kIsGroupMode>
674  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
675  MakeKargsImpl(const void* q_ptr,
676  const void* k_ptr,
677  const void* v_ptr,
678  const void* bias_ptr,
679  void* rand_val_ptr,
680  void* lse_ptr,
681  void* o_ptr,
682  const void* seqstart_q_ptr,
683  const void* seqstart_k_ptr,
684  const void* seqlen_k_ptr,
685  ck_tile::index_t hdim_q,
686  ck_tile::index_t hdim_v,
687  ck_tile::index_t num_head_q,
688  ck_tile::index_t nhead_ratio_qk,
689  float scale_s,
690  float scale_p,
691  float scale_o,
692  float logits_soft_cap,
693  ck_tile::index_t stride_q,
694  ck_tile::index_t stride_k,
695  ck_tile::index_t stride_v,
696  ck_tile::index_t stride_bias,
697  ck_tile::index_t stride_randval,
698  ck_tile::index_t stride_o,
699  ck_tile::index_t nhead_stride_q,
700  ck_tile::index_t nhead_stride_k,
701  ck_tile::index_t nhead_stride_v,
702  ck_tile::index_t nhead_stride_bias,
703  ck_tile::index_t nhead_stride_randval,
704  ck_tile::index_t nhead_stride_lse,
705  ck_tile::index_t nhead_stride_o,
706  ck_tile::index_t window_size_left,
707  ck_tile::index_t window_size_right,
708  ck_tile::index_t mask_type,
709  ck_tile::index_t min_seqlen_q,
710  float p_drop,
711  bool s_randval,
712  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
713  drop_seed_offset,
714  const void* seqstart_padded_q_ptr = nullptr,
715  const void* seqstart_padded_k_ptr = nullptr)
716  {
717  Kargs kargs{{q_ptr,
718  k_ptr,
719  v_ptr,
720  o_ptr,
721  -1, // seqlen will be updated by another pointer
722  -1, //
723  hdim_q,
724  hdim_v,
725  num_head_q,
726  nhead_ratio_qk,
727 #if CK_TILE_FMHA_FWD_FAST_EXP2
728  static_cast<float>(scale_s * ck_tile::log2e_v<>),
729 #else
730  scale_s,
731 #endif
732  stride_q,
733  stride_k,
734  stride_v,
735  stride_o,
736  nhead_stride_q,
737  nhead_stride_k,
738  nhead_stride_v,
739  nhead_stride_o}, // args for common karg
740  {}, // placeholder for bias
741  {}, // placeholder for mask
742  {}, // placeholder for lse
743  {}, // placeholder for fp8_static_quant args
744  {}, // placeholder for dropout
745  {}, // placeholder for logits_soft_cap
746  {}, // placeholder for min_seqlen_q
747  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
748  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
749  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
750 
751  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
752  {
753  kargs.bias_ptr = bias_ptr;
754  kargs.stride_bias = stride_bias;
755  kargs.nhead_stride_bias = nhead_stride_bias;
756  }
757  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
758  {
759  kargs.alibi_slope_ptr = bias_ptr;
760  kargs.alibi_slope_stride = stride_bias;
761  }
762  if constexpr(kHasMask)
763  {
764  kargs.window_size_left = window_size_left;
765  kargs.window_size_right = window_size_right;
766  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
767  }
768  if constexpr(kStoreLSE)
769  {
770  kargs.lse_ptr = lse_ptr;
771  kargs.nhead_stride_lse = nhead_stride_lse;
772  }
773  if constexpr(kDoFp8StaticQuant)
774  {
775  kargs.scale_p = scale_p;
776  kargs.scale_o = scale_o;
777  }
778  if constexpr(kHasDropout)
779  {
780  if(drop_seed_offset.index() == 0) // seed & offset come from host
781  {
782  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
783  kargs.init_dropout(p_drop, seed, offset);
784  }
785  else // seed & offset come from device
786  {
787  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
788  kargs.init_dropout(p_drop,
789  reinterpret_cast<const uint64_t*>(seed_ptr),
790  reinterpret_cast<const uint64_t*>(offset_ptr));
791  }
792 
793  kargs.rand_val_ptr = rand_val_ptr;
794  kargs.stride_randval = stride_randval;
795  kargs.nhead_stride_randval = nhead_stride_randval;
796  kargs.is_store_randval = s_randval;
797  }
798  if constexpr(kHasLogitsSoftCap)
799  {
800  kargs.init_logits_soft_cap(logits_soft_cap);
801  }
802  if constexpr(kSkipMinSeqlenQ)
803  {
804  kargs.min_seqlen_q = min_seqlen_q;
805  }
806 
807  kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
808  kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
809  return kargs;
810  }
811 
812  // std::variant<> can't be constructed from a braced initializer list; overload kept for backward compatibility
813  template <bool Cond = kIsGroupMode>
814  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
815  MakeKargs(const void* q_ptr,
816  const void* k_ptr,
817  const void* v_ptr,
818  const void* bias_ptr,
819  void* rand_val_ptr,
820  void* lse_ptr,
821  void* o_ptr,
822  const void* seqstart_q_ptr,
823  const void* seqstart_k_ptr,
824  const void* seqlen_k_ptr,
825  ck_tile::index_t hdim_q,
826  ck_tile::index_t hdim_v,
827  ck_tile::index_t num_head_q,
828  ck_tile::index_t nhead_ratio_qk,
829  float scale_s,
830  float scale_p,
831  float scale_o,
832  float logits_soft_cap,
833  ck_tile::index_t stride_q,
834  ck_tile::index_t stride_k,
835  ck_tile::index_t stride_v,
836  ck_tile::index_t stride_bias,
837  ck_tile::index_t stride_randval,
838  ck_tile::index_t stride_o,
839  ck_tile::index_t nhead_stride_q,
840  ck_tile::index_t nhead_stride_k,
841  ck_tile::index_t nhead_stride_v,
842  ck_tile::index_t nhead_stride_bias,
843  ck_tile::index_t nhead_stride_randval,
844  ck_tile::index_t nhead_stride_lse,
845  ck_tile::index_t nhead_stride_o,
846  ck_tile::index_t window_size_left,
847  ck_tile::index_t window_size_right,
848  ck_tile::index_t mask_type,
849  ck_tile::index_t min_seqlen_q,
850  float p_drop,
851  bool s_randval,
852  const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
853  const void* seqstart_padded_q_ptr = nullptr,
854  const void* seqstart_padded_k_ptr = nullptr)
855  {
856  return MakeKargsImpl(
857  q_ptr,
858  k_ptr,
859  v_ptr,
860  bias_ptr,
861  rand_val_ptr,
862  lse_ptr,
863  o_ptr,
864  seqstart_q_ptr,
865  seqstart_k_ptr,
866  seqlen_k_ptr,
867  hdim_q,
868  hdim_v,
869  num_head_q,
870  nhead_ratio_qk,
871  scale_s,
872  scale_p,
873  scale_o,
874  logits_soft_cap,
875  stride_q,
876  stride_k,
877  stride_v,
878  stride_bias,
879  stride_randval,
880  stride_o,
881  nhead_stride_q,
882  nhead_stride_k,
883  nhead_stride_v,
884  nhead_stride_bias,
885  nhead_stride_randval,
886  nhead_stride_lse,
887  nhead_stride_o,
888  window_size_left,
889  window_size_right,
890  mask_type,
891  min_seqlen_q,
892  p_drop,
893  s_randval,
894  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
895  seqstart_padded_q_ptr,
896  seqstart_padded_k_ptr);
897  }
898 
899  // std::variant<> can't be constructed from a braced initializer list; overload kept for backward compatibility
900  template <bool Cond = kIsGroupMode>
901  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
902  MakeKargs(const void* q_ptr,
903  const void* k_ptr,
904  const void* v_ptr,
905  const void* bias_ptr,
906  void* rand_val_ptr,
907  void* lse_ptr,
908  void* o_ptr,
909  const void* seqstart_q_ptr,
910  const void* seqstart_k_ptr,
911  const void* seqlen_k_ptr,
912  ck_tile::index_t hdim_q,
913  ck_tile::index_t hdim_v,
914  ck_tile::index_t num_head_q,
915  ck_tile::index_t nhead_ratio_qk,
916  float scale_s,
917  float scale_p,
918  float scale_o,
919  float logits_soft_cap,
920  ck_tile::index_t stride_q,
921  ck_tile::index_t stride_k,
922  ck_tile::index_t stride_v,
923  ck_tile::index_t stride_bias,
924  ck_tile::index_t stride_randval,
925  ck_tile::index_t stride_o,
926  ck_tile::index_t nhead_stride_q,
927  ck_tile::index_t nhead_stride_k,
928  ck_tile::index_t nhead_stride_v,
929  ck_tile::index_t nhead_stride_bias,
930  ck_tile::index_t nhead_stride_randval,
931  ck_tile::index_t nhead_stride_lse,
932  ck_tile::index_t nhead_stride_o,
933  ck_tile::index_t window_size_left,
934  ck_tile::index_t window_size_right,
935  ck_tile::index_t mask_type,
936  ck_tile::index_t min_seqlen_q,
937  float p_drop,
938  bool s_randval,
939  const std::tuple<const void*, const void*>& drop_seed_offset,
940  const void* seqstart_padded_q_ptr = nullptr,
941  const void* seqstart_padded_k_ptr = nullptr)
942  {
943  return MakeKargsImpl(
944  q_ptr,
945  k_ptr,
946  v_ptr,
947  bias_ptr,
948  rand_val_ptr,
949  lse_ptr,
950  o_ptr,
951  seqstart_q_ptr,
952  seqstart_k_ptr,
953  seqlen_k_ptr,
954  hdim_q,
955  hdim_v,
956  num_head_q,
957  nhead_ratio_qk,
958  scale_s,
959  scale_p,
960  scale_o,
961  logits_soft_cap,
962  stride_q,
963  stride_k,
964  stride_v,
965  stride_bias,
966  stride_randval,
967  stride_o,
968  nhead_stride_q,
969  nhead_stride_k,
970  nhead_stride_v,
971  nhead_stride_bias,
972  nhead_stride_randval,
973  nhead_stride_lse,
974  nhead_stride_o,
975  window_size_left,
976  window_size_right,
977  mask_type,
978  min_seqlen_q,
979  p_drop,
980  s_randval,
981  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
982  seqstart_padded_q_ptr,
983  seqstart_padded_k_ptr);
984  }
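  // Illustrative group-mode (varlen) setup: three sequences of lengths 3, 5, and 2
  // packed back to back give batch+1 cumulative starts; with no physical padding
  // the seqstart_padded_* pointers can stay nullptr:
  //
  //   const ck_tile::index_t seqstart_q[] = {0, 3, 8, 10};
  //   const ck_tile::index_t seqstart_k[] = {0, 3, 8, 10};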
985 
986  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
987  ck_tile::index_t nhead_,
988  ck_tile::index_t seqlen_q_,
989  ck_tile::index_t hdim_v_,
990  bool has_padded_seqlen_k = false)
991  {
992  // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
993  if(has_padded_seqlen_k)
994  {
995  // TODO: this may need tuning
996  return dim3(nhead_,
997  batch_size_,
998  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
999  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
1000  }
1001  else
1002  {
1003  // TODO: this may need tuning
1004  return dim3(nhead_,
1005  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
1006  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
1007  batch_size_);
1008  }
1009  }
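  // Worked example (illustrative): batch = 2, nhead = 8, seqlen_q = 1024,
  // hdim_v = 128 with kM0 = 128, kN1 = 128 and no padded seqlen_k gives
  //
  //   dim3 grid = Kernel::GridSize(2, 8, 1024, 128, false); // dim3(8, 8, 2)
  //
  // i.e. 8 (m-tile, n-tile) pairs along .y, batch along .z, heads along .x.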
1010 
1011  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
1012  {
1013  bool has_padded_seqlen_k = false;
1014 
1015  if constexpr(kIsGroupMode)
1016  has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
1017 
1018  if(has_padded_seqlen_k)
1019  {
1020  // const index_t num_tile_m0 = seqlen_q / kM0;
1021  const index_t num_tile_n1 =
1022  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1023 
1024  const index_t i_block = blockIdx.z;
1025  const index_t i_nhead = blockIdx.x;
1026  const index_t i_batch = blockIdx.y;
1027 
1028  const auto f = [](index_t dividend, index_t divisor) {
1029  index_t quotient = dividend / divisor;
1030  index_t modulus = dividend - quotient * divisor;
1031  return ck_tile::make_tuple(quotient, modulus);
1032  };
1033 
1034  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1035 
1036  if constexpr(kHasMask)
1037  {
1038  // assume that num_tile_n1 is always 1
1039  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1040  }
1041  else
1042  {
1043  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1044  }
1045  }
1046  else
1047  {
1048  // const index_t num_tile_m0 = seqlen_q / kM0;
1049  const index_t num_tile_n1 =
1050  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1051 
1052  const index_t i_block = blockIdx.y; // blockIdx.x
1053  const index_t i_nhead = blockIdx.x; // blockIdx.y
1054  const index_t i_batch = blockIdx.z;
1055 
1056  const auto f = [](index_t dividend, index_t divisor) {
1057  index_t quotient = dividend / divisor;
1058  index_t modulus = dividend - quotient * divisor;
1059  return ck_tile::make_tuple(quotient, modulus);
1060  };
1061 
1062  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1063 
1064  if constexpr(kHasMask)
1065  {
1066  // assume that num_tile_n1 is always 1
1067  return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1068  }
1069  else
1070  {
1071  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1072  }
1073  }
1074  }
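  // Note: when masking is enabled the m-tile order is reversed (gridDim - 1 -
  // i_tile_m), so the tiles with the most surviving work under a causal mask are
  // issued first; this tends to balance per-CU load, though the exact benefit is
  // schedule-dependent.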
1075 
1076  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
1077 
1079  {
1080  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
1081  }
1082 
1083  CK_TILE_DEVICE void operator()(Kargs kargs) const
1084  {
1085  if constexpr(kIsAvailable)
1086  run_(std::move(kargs));
1087  }
1088 
1089  CK_TILE_DEVICE void run_(Kargs kargs) const
1090  {
1091  if constexpr(kPipelineName != "qr_async_trload")
1092  {
1093  // allocate LDS
1094  __shared__ char smem_ptr[GetSmemSize()];
1095 
1096  // divide problem
1097  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1098 
1099  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
1100  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
1101 
1102  long_index_t batch_offset_q = 0;
1103  long_index_t batch_offset_k = 0;
1104  long_index_t batch_offset_v = 0;
1105  long_index_t batch_offset_bias = 0;
1106  long_index_t batch_offset_randval = 0;
1107  long_index_t batch_offset_lse = 0;
1108  long_index_t batch_offset_o = 0;
1109 
1110  if constexpr(kIsGroupMode)
1111  {
1112  // logical and physical (padded) starts
1113  const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
1114  const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
1115 
1116  const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
1117  ? kargs.seqstart_padded_q_ptr[i_batch]
1118  : query_start_unpadded;
1119  const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
1120  ? kargs.seqstart_padded_k_ptr[i_batch]
1121  : key_start_unpadded;
1122 
1123  // DRAM base offsets use physical padded starts
1124  batch_offset_q = query_start_padded * kargs.stride_q;
1125  batch_offset_k = key_start_padded * kargs.stride_k;
1126  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1127  {
1128  batch_offset_v = key_start_padded * kargs.stride_v;
1129  }
1130  else
1131  {
1132  batch_offset_v = key_start_padded;
1133  }
1134  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1135  {
1136  batch_offset_bias = query_start_padded * kargs.stride_bias;
1137  }
1138  if constexpr(kStoreLSE)
1139  {
1140  // LSE stays indexed by unpadded starts
1141  batch_offset_lse = query_start_unpadded;
1142  }
1143  if constexpr(kHasDropout)
1144  {
1145  batch_offset_randval = query_start_padded * kargs.stride_randval;
1146  }
1147  batch_offset_o = query_start_padded * kargs.stride_o;
1148 
1149  // real logical lengths (exclude PAD)
1150  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1151  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1152 
1153  if constexpr(kSkipMinSeqlenQ)
1154  {
1155  if(kargs.seqlen_q <= kargs.min_seqlen_q)
1156  {
1157  return;
1158  }
1159  }
1160 
1161  // terminate unnecessary blocks earlier
1162  if(kargs.seqlen_q <= i_m0)
1163  {
1164  return;
1165  }
1166 
1167  if(kargs.seqlen_k_ptr != nullptr)
1168  {
1169  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1170  }
1171  else
1172  {
1173  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1174  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1175  }
1176  }
1177  else
1178  {
1179  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1180  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1181  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1182  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1183  {
1184  batch_offset_bias =
1185  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1186  }
1187  if constexpr(kStoreLSE)
1188  {
1189  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1190  }
1191  if constexpr(kHasDropout)
1192  {
1193  batch_offset_randval =
1194  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
1195  }
1196  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1197 
1198  // If cumulative seqlen pointers are provided, override per-batch effective lengths
1199  if(kargs.cu_seqlen_q_ptr != nullptr)
1200  {
1201  kargs.seqlen_q =
1202  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1203  }
1204  if(kargs.cu_seqlen_kv_ptr != nullptr)
1205  {
1206  kargs.seqlen_k =
1207  kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
1208  }
1209  }
1210 
1211  // for simplicity, instead of carrying a batch stride we just offset the pointers by batch
1212  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1213  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1214  batch_offset_q;
1215  const KDataType* k_ptr =
1216  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1217  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1218  batch_offset_k;
1219  const VDataType* v_ptr =
1220  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1221  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1222  batch_offset_v;
1223  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1224  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1225  batch_offset_o;
1226 
1227  // Q/K/V DRAM and DRAM window
1228  const auto q_dram = [&]() {
1229  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1230  q_ptr,
1231  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1232  make_tuple(kargs.stride_q, 1),
1233  number<FmhaPipeline::kAlignmentQ>{},
1234  number<1>{});
1235  if constexpr(FmhaPipeline::kQLoadOnce)
1236  {
1237  return pad_tensor_view(q_dram_naive,
1241  }
1242  else
1243  {
1244  return pad_tensor_view(
1245  q_dram_naive,
1248  }
1249  }();
1250  const auto k_dram = [&]() {
1251  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1252  k_ptr,
1253  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1254  make_tuple(kargs.stride_k, 1),
1255  number<FmhaPipeline::kAlignmentK>{},
1256  number<1>{});
1257 
1258  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1259  return pad_tensor_view(
1260  k_dram_naive,
1263  }();
1264  const auto v_dram = [&]() {
1265  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1266  {
1267  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1268  v_ptr,
1269  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1270  make_tuple(kargs.stride_v, 1),
1271  number<FmhaPipeline::kAlignmentV>{},
1272  number<1>{});
1273 
1274  const auto v_dram_transposed = transform_tensor_view(
1275  v_dram_naive,
1277  make_pass_through_transform(kargs.seqlen_k)),
1280 
1281  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1282  return pad_tensor_view(
1283  v_dram_transposed,
1286  }
1287  else
1288  {
1289  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1290  v_ptr,
1291  make_tuple(kargs.hdim_v, kargs.seqlen_k),
1292  make_tuple(kargs.stride_v, 1),
1293  number<FmhaPipeline::kAlignmentV>{},
1294  number<1>{});
1295 
1296  constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
1297  return pad_tensor_view(
1298  v_dram_naive,
1301  }
1302  }();
1303 
1304  auto q_dram_window = make_tile_window(
1305  q_dram,
1306  [&]() {
1307  if constexpr(FmhaPipeline::kQLoadOnce)
1310  else
1312  }(),
1313  {i_m0, 0});
1314 
1315  auto k_dram_window = make_tile_window(
1316  k_dram,
1318  {0, 0});
1319 
1320  auto v_dram_window = make_tile_window(
1321  v_dram,
1323  {i_n1, 0});
1326  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1327  constexpr auto bias_dram_window_lengths =
1328  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
1329  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1330  {
1331  const BiasDataType* bias_ptr =
1332  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1333  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1334  batch_offset_bias;
1335 
1336  const auto bias_dram = [&]() {
1337  const auto bias_dram_naive =
1338  make_naive_tensor_view<address_space_enum::global>(
1339  bias_ptr,
1340  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1341  make_tuple(kargs.stride_bias, 1),
1342  number<FmhaPipeline::kAlignmentBias>{},
1343  number<1>{});
1344 
1345  return pad_tensor_view(bias_dram_naive,
1346  bias_dram_window_lengths,
1348  }();
1349 
1350  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1351  }
1352  else
1353  {
1354  return make_null_tile_window(bias_dram_window_lengths);
1355  }
1356  }();
1357 
1358  // lse
1359  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1360  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1361  if constexpr(kStoreLSE)
1362  {
1363  LSEDataType* lse_ptr =
1364  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1365  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
1366  batch_offset_lse;
1367 
1368  const auto lse_dram = [&]() {
1369  const auto lse_dram_naive =
1370  make_naive_tensor_view<address_space_enum::global>(
1371  lse_ptr,
1372  make_tuple(kargs.seqlen_q),
1373  make_tuple(1),
1374  number<1>{},
1375  number<1>{});
1376 
1377  return pad_tensor_view(
1378  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1379  }();
1380 
1381  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1382  }
1383  else
1384  {
1385  return make_null_tile_window(lse_dram_window_lengths);
1386  }
1387  }();
1388 
1389  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1390  if constexpr(kHasDropout)
1391  {
1392  return BlockDropout{i_batch_,
1393  i_nhead_,
1394  kargs.num_head_q,
1395  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1396  : *kargs.drop_seed.ptr,
1397  kargs.is_drop_seed_offset_from_host
1398  ? kargs.drop_offset.val
1399  : *kargs.drop_offset.ptr,
1400  kargs.rp_undrop,
1401  kargs.p_undrop_in_uint8_t,
1402  kargs.is_store_randval};
1403  }
1404  else
1405  {
1406  return NullBlockDropout{};
1407  };
1408  }();
1409 
1410  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1411  constexpr auto randval_dram_window_lengths =
1412  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
1413  if constexpr(kHasDropout)
1414  {
1415  RandValOutputDataType* rand_val_ptr =
1416  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1417  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1418  batch_offset_randval;
1419 
1420  const auto randval_dram = [&]() {
1421  const auto randval_dram_naive =
1422  make_naive_tensor_view<address_space_enum::global>(
1423  rand_val_ptr,
1424  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1425  make_tuple(kargs.stride_randval, 1),
1426  number<1>{},
1427  number<1>{});
1428 
1429  return pad_tensor_view(randval_dram_naive,
1430  randval_dram_window_lengths,
1432  }();
1433 
1434  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1435  }
1436  else
1437  {
1438  return make_null_tile_window(randval_dram_window_lengths);
1439  }
1440  }();
1441 
1442  FmhaMask mask = [&]() {
1443  if constexpr(kHasMask)
1444  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1445  kargs.window_size_left,
1446  kargs.window_size_right,
1447  kargs.seqlen_q,
1448  kargs.seqlen_k,
1449  kargs.mask_type == ck_tile::GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
1450  else
1451  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1452  }();
1453 
1454  // workaround: structured bindings cannot be captured by lambdas before C++20
1455  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1456  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1457  {
1458  // data loading, shared by entire wg
1459  // TODO: how to use s_read?
1460  SaccDataType slope =
1461  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1462  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1463 #if CK_TILE_FMHA_FWD_FAST_EXP2
1464  slope *= ck_tile::log2e_v<>;
1465 #endif
1466  if constexpr(kHasMask)
1467  {
1468  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1469  kargs.window_size_left,
1470  kargs.window_size_right,
1471  kargs.seqlen_q,
1472  kargs.seqlen_k,
1473  kargs.mask_type);
1474  }
1475  else
1476  {
1477  return ck_tile::Alibi<SaccDataType, true>{
1478  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1479  }
1480  }
1481  else
1482  {
1483  return ck_tile::EmptyPositionEncoding<SaccDataType>{};
1484  }
1485  }();
1486 
1487  AttentionVariant variant;
1488  const auto variant_params = [&] {
1489  if constexpr(kHasLogitsSoftCap)
1490  {
1491  return ck_tile::LogitsSoftCapParams<FmhaMask>{
1492  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1493  }
1494  else
1495  {
1496  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1497  }
1498  }();
1499 
1500  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1501 
1502  auto o_acc_tile = [&]() {
1503  if constexpr(kDoFp8StaticQuant)
1504  {
1505  auto o_acc_element_func = [&]() {
1506  if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
1507  return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
1508  ck_tile::scales{kargs.scale_o});
1509  else
1510  return ck_tile::scales{kargs.scale_o};
1511  }();
1512  return FmhaPipeline{}(q_dram_window,
1513  identity{}, // q_element_func
1514  k_dram_window,
1515  identity{}, // k_element_func
1516  v_dram_window,
1517  identity{}, // v_element_func
1518  bias_dram_window,
1519  identity{}, // bias_element_func
1520  randval_dram_window,
1521  lse_dram_window,
1522  identity{}, // lse_element_func
1523  identity{}, // s_acc_element_func
1524  scales{kargs.scale_p}, // p_compute_element_func
1525  o_acc_element_func, // o_acc_element_func
1526  mask,
1527  position_encoding,
1528  kargs.scale_s,
1529  variant,
1530  variant_params,
1531  block_indices,
1532  smem_ptr,
1533  dropout);
1534  }
1535  else
1536  {
1537  return FmhaPipeline{}(q_dram_window,
1538  k_dram_window,
1539  v_dram_window,
1540  bias_dram_window,
1541  randval_dram_window,
1542  lse_dram_window,
1543  mask,
1544  position_encoding,
1545  kargs.scale_s,
1546  variant,
1547  variant_params,
1548  block_indices,
1549  smem_ptr,
1550  dropout);
1551  }
1552  }();
1553 
1554  // O DRAM and O DRAM window
1555  auto o_dram = [&]() {
1556  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1557  o_ptr,
1558  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1559  make_tuple(kargs.stride_o, 1),
1560  number<FmhaPipeline::kAlignmentO>{},
1561  number<1>{});
1562 
1563  return pad_tensor_view(
1564  o_dram_naive,
1567  }();
1568 
1569  auto o_dram_window = make_tile_window(
1570  o_dram,
1572  {i_m0, i_n1});
1573 
1574  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1575  }
1576  else
1577  {
1578  // TODO: refine the logic here.
1579  // In the decode case:
1580  // 1. we don't expect K/V data to be reused across thread groups, so bypass the cache
1581  // 2. limit LDS usage, since we want higher occupancy
1582  // In the prefill case:
1583  // 1. we expect K/V data to be reused across thread groups, so use the cache
1584  // 2. use more LDS, since we want better memory-latency hiding
1585  // If SplitKV is off, we don't expect Q data to be reused across thread groups, so
1586  // bypass the cache
1587  constexpr bool PrefillCase = FmhaPipeline::kM0 >= 128;
1588  // divide problem
1589  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1590 
1591  const index_t i_m0 = i_tile_m * FmhaPipeline::kM0;
1592  const index_t i_n1 = i_tile_n * FmhaPipeline::kN1;
1593 
1594  long_index_t batch_offset_q = 0;
1595  long_index_t batch_offset_k = 0; // unused for paged-kvcache
1596  long_index_t batch_offset_v = 0; // unused for paged-kvcache
1597  long_index_t batch_offset_bias = 0;
1598  long_index_t batch_offset_lse = 0;
1599  long_index_t batch_offset_o = 0;
1600  // index_t kv_l2p_offset =
1601  // 0; // logical-to-physical offset of seqlen_k coordinate. only used for
1602  // paged-kvcache
1603 
1604  if constexpr(kIsGroupMode)
1605  {
1606  // get starting offset for each batch
1607  const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
1608  const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
1609 
1610  const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
1611  ? kargs.seqstart_padded_q_ptr[i_batch]
1612  : query_start_unpadded;
1613  const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
1614  ? kargs.seqstart_padded_k_ptr[i_batch]
1615  : key_start_unpadded;
1616 
1617  batch_offset_q = query_start_padded * kargs.stride_q;
1618  batch_offset_k = key_start_padded * kargs.stride_k;
1619  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1620  {
1621  batch_offset_v = key_start_padded * kargs.stride_v;
1622  }
1623  else
1624  {
1625  // col-major V: offset along seqlen dimension is scalar index
1626  batch_offset_v = key_start_padded;
1627  }
1628  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1629  {
1630  batch_offset_bias = query_start_padded * kargs.stride_bias;
1631  }
1632 
1633  // LSE layout is [nhead, total_seqlen], index by unpadded start
1634  batch_offset_lse = query_start_unpadded;
1635  batch_offset_o = query_start_padded * kargs.stride_o;
1636 
1637  // get real # queries & # keys under group mode
1638  kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
1639 
1640  // the number of required blocks differs per group, so terminate unnecessary
1641  // blocks early
1642  if(kargs.seqlen_q <= i_m0)
1643  {
1644  return;
1645  }
1646 
1647  if(kargs.seqlen_k_ptr != nullptr)
1648  {
1649  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1650  }
1651  else
1652  {
1653  kargs.seqlen_k =
1654  kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
1655  }
1656  }
1657  else
1658  {
1659  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1660  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1661  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1662  if constexpr(kStoreLSE)
1663  {
1664  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1665  }
1666  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1667 
1668  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1669  {
1670  batch_offset_bias =
1671  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1672  }
1673 
1674  // If cumulative seqlen pointers are provided, override per-batch effective lengths
1675  if(kargs.cu_seqlen_q_ptr != nullptr)
1676  {
1677  kargs.seqlen_q =
1678  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1679  }
1680  if(kargs.cu_seqlen_kv_ptr != nullptr)
1681  {
1682  kargs.seqlen_k =
1683  kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
1684  }
1685  }
1686 
1687  // for simplicity, instead of carrying a batch stride we just offset the pointers by batch
1688  const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk;
1689 
1690  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1691  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1692  batch_offset_q;
1693  const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1694  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
1695  batch_offset_k;
1696  const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1697  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
1698  batch_offset_v;
1699 
1700  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1701  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1702  batch_offset_o;
1703 
1704  // Q/K/V DRAM and DRAM window
1705  const auto q_dram = [&] {
1706  const auto q_dram_naive = [&] {
1707  {
1708  return make_naive_tensor_view<address_space_enum::global,
1709  memory_operation_enum::set,
1711  q_ptr,
1712  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1713  make_tuple(kargs.stride_q, 1),
1714  number<FmhaPipeline::kAlignmentQ>{},
1715  number<1>{});
1716  }
1717  }();
1718 
1719  if constexpr(FmhaPipeline::kQLoadOnce)
1720  {
1721  const auto seqlen_q = kargs.seqlen_q;
1722  const auto q_dram_pad = pad_tensor_view(
1723  q_dram_naive,
1726 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1727  constexpr index_t LDSLayerSize = 256 / sizeof(QDataType);
1728  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1729 
1730  if constexpr(XorLengthFold > 1)
1731  {
1732  const auto q_dram_unmerged = transform_tensor_view(
1733  q_dram_pad,
1734  make_tuple(
1736  make_tuple(seqlen_q / XorLengthFold, XorLengthFold)),
1740 
1741  const auto q_dram_merged = transform_tensor_view(
1742  q_dram_unmerged,
1743  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1745  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1748 
1749  const auto q_dram_unmerged_xor = transform_tensor_view(
1750  q_dram_merged,
1751  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1757 
1758  const auto q_dram_permuted = transform_tensor_view(
1759  q_dram_unmerged_xor,
1760  make_tuple(
1762  make_tuple(seqlen_q / XorLengthFold,
1767 
1768  const auto q_dram_tmp = transform_tensor_view(
1769  q_dram_permuted,
1770  make_tuple(
1771  make_pass_through_transform(seqlen_q / XorLengthFold),
1774  number<FmhaPipeline::kQKHeaddim /
1775  FmhaPipeline::kAlignmentQ>{})),
1779 
1780  return transform_tensor_view(
1781  q_dram_tmp,
1782  make_tuple(
1784  make_tuple(seqlen_q / XorLengthFold, number<XorLengthFold>{})),
1790  }
1791  else
1792 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1793  {
1794  const auto q_dram_unmerged = transform_tensor_view(
1795  q_dram_pad,
1796  make_tuple(
1797  make_pass_through_transform(seqlen_q),
1803 
1804  const auto q_dram_permuted = transform_tensor_view(
1805  q_dram_unmerged,
1806  make_tuple(
1807  make_xor_transform(make_tuple(seqlen_q,
1808  number<FmhaPipeline::kQKHeaddim /
1809  FmhaPipeline::kAlignmentQ>{})),
1813 
1814  return transform_tensor_view(
1815  q_dram_permuted,
1816  make_tuple(
1817  make_pass_through_transform(seqlen_q),
1823  }
1824  }
1825  else
1826  {
1827  return pad_tensor_view(
1828  q_dram_naive,
1831  }
1832  }();
1833 
1834  const auto make_k_dram = [&](const KDataType* data, index_t height) {
1835  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1836  data, // will update this pointer if using paged-kvcache
1837  make_tuple(height, kargs.hdim_q),
1838  make_tuple(kargs.stride_k, 1),
1839  number<FmhaPipeline::kAlignmentK>{},
1840  number<1>{});
1841 
1842  const auto k_dram_pad = pad_tensor_view(
1843  k_dram_naive,
1846 
1847  constexpr auto kDramTileK =
1848  FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0;
1849 
1850 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1851  constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
1852  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1853 
1854  if constexpr(XorLengthFold > 1)
1855  {
1856  const auto k_dram_unmerged = transform_tensor_view(
1857  k_dram_pad,
1859  make_tuple(height / XorLengthFold, XorLengthFold)),
1863 
1864  const auto k_dram_merged = transform_tensor_view(
1865  k_dram_unmerged,
1866  make_tuple(make_pass_through_transform(height / XorLengthFold),
1868  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1871 
1872  const auto k_dram_unmerged_xor = transform_tensor_view(
1873  k_dram_merged,
1874  make_tuple(make_pass_through_transform(height / XorLengthFold),
1880 
1881  const auto k_dram_permuted = transform_tensor_view(
1882  k_dram_unmerged_xor,
1883  make_tuple(
1885  make_tuple(height / XorLengthFold,
1890 
1891  const auto k_dram_tmp = transform_tensor_view(
1892  k_dram_permuted,
1893  make_tuple(
1894  make_pass_through_transform(height / XorLengthFold),
1897  number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
1901 
1902  return transform_tensor_view(
1903  k_dram_tmp,
1904  make_tuple(
1906  make_tuple(height / XorLengthFold, number<XorLengthFold>{})),
1912  }
1913  else
1914 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1915  {
1916  const auto k_dram_unmerged = transform_tensor_view(
1917  k_dram_pad,
1920  make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
1921  FmhaPipeline::kAlignmentK>{},
1922  number<kDramTileK / FmhaPipeline::kAlignmentK>{},
1926 
1927  const auto k_dram_permuted = transform_tensor_view(
1928  k_dram_unmerged,
1929  make_tuple(
1933  number<FmhaPipeline::kQKHeaddim / kDramTileK /
1934  FmhaPipeline::kAlignmentK>{}),
1938 
1939  return transform_tensor_view(
1940  k_dram_permuted,
1943  make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
1944  FmhaPipeline::kAlignmentK>{},
1945  number<kDramTileK / FmhaPipeline::kAlignmentK>{},
1949  }
1950  };
1951  const auto k_dram = [&]() {
1952  {
1953  return make_k_dram(k_ptr, kargs.seqlen_k);
1954  }
1955  }();
1956 
1957  const auto make_v_dram = [&](const VDataType* data, index_t length) {
1958  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1959  data, // will update this pointer if using paged-kvcache
1960  make_tuple(length, kargs.hdim_v),
1961  make_tuple(kargs.stride_v, 1),
1962  number<FmhaPipeline::kAlignmentV>{},
1963  number<1>{});
1964 
1965  // TODO: Add kVHeadDim
1966  constexpr index_t XorGroupSize =
1967  FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
1968 
1969  const auto v_dram_pad = pad_tensor_view(
1970  v_dram_naive,
1973 
1974 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1975  constexpr index_t LDSLayerSize = 256 / sizeof(VDataType);
1976  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1977 
1978  if constexpr(XorLengthFold > 1)
                {
                    // same fold-and-xor chain as for K, with XorGroupSize as the
                    // vector-group width (mappings reconstructed; read as indicative)
                    // [length, kQKHeaddim] -> [length / XorLengthFold, XorLengthFold, kQKHeaddim]
                    const auto v_dram_unmerged = transform_tensor_view(
                        v_dram_pad,
                        make_tuple(make_unmerge_transform(
                                       make_tuple(length / XorLengthFold, XorLengthFold)),
                                   make_pass_through_transform(number<FmhaPipeline::kQKHeaddim>{})),
                        make_tuple(sequence<0>{}, sequence<1>{}),
                        make_tuple(sequence<0, 1>{}, sequence<2>{}));

                    // -> [length / XorLengthFold, XorLengthFold * kQKHeaddim]
                    const auto v_dram_merged = transform_tensor_view(
                        v_dram_unmerged,
                        make_tuple(make_pass_through_transform(length / XorLengthFold),
                                   make_merge_transform_v3_division_mod(make_tuple(
                                       XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
                        make_tuple(sequence<0>{}, sequence<1, 2>{}),
                        make_tuple(sequence<0>{}, sequence<1>{}));

                    // -> [length / XorLengthFold, XorLengthFold * kQKHeaddim / XorGroupSize, XorGroupSize]
                    const auto v_dram_unmerged_xor = transform_tensor_view(
                        v_dram_merged,
                        make_tuple(
                            make_pass_through_transform(length / XorLengthFold),
                            make_unmerge_transform(
                                make_tuple(number<XorLengthFold * FmhaPipeline::kQKHeaddim /
                                                  XorGroupSize>{},
                                           number<XorGroupSize>{}))),
                        make_tuple(sequence<0>{}, sequence<1>{}),
                        make_tuple(sequence<0>{}, sequence<1, 2>{}));

                    // xor the folded-row index against the group index
                    const auto v_dram_permuted = transform_tensor_view(
                        v_dram_unmerged_xor,
                        make_tuple(
                            make_xor_transform(make_tuple(length / XorLengthFold,
                                                          number<XorLengthFold *
                                                                 FmhaPipeline::kQKHeaddim /
                                                                 XorGroupSize>{})),
                            make_pass_through_transform(number<XorGroupSize>{})),
                        make_tuple(sequence<0, 1>{}, sequence<2>{}),
                        make_tuple(sequence<0, 1>{}, sequence<2>{}));

                    // -> [length / XorLengthFold, XorLengthFold, kQKHeaddim / XorGroupSize, XorGroupSize]
                    const auto v_dram_tmp = transform_tensor_view(
                        v_dram_permuted,
                        make_tuple(make_pass_through_transform(length / XorLengthFold),
                                   make_unmerge_transform(
                                       make_tuple(XorLengthFold,
                                                  number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
                                   make_pass_through_transform(number<XorGroupSize>{})),
                        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
                        make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));

                    // merge back to [length, kQKHeaddim]
                    return transform_tensor_view(
                        v_dram_tmp,
                        make_tuple(make_merge_transform_v3_division_mod(
                                       make_tuple(length / XorLengthFold, number<XorLengthFold>{})),
                                   make_merge_transform_v3_division_mod(
                                       make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
                                                  number<XorGroupSize>{}))),
                        make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
                        make_tuple(sequence<0>{}, sequence<1>{}));
                }
                else
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
                {
                    // split the head dim into xor groups
                    // (mappings reconstructed; read as indicative)
                    const auto v_dram_unmerged = transform_tensor_view(
                        v_dram_pad,
                        make_tuple(make_pass_through_transform(length),
                                   make_unmerge_transform(
                                       make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
                                                  number<XorGroupSize>{}))),
                        make_tuple(sequence<0>{}, sequence<1>{}),
                        make_tuple(sequence<0>{}, sequence<1, 2>{}));

                    // xor the row index against the group index
                    const auto v_dram_permuted = transform_tensor_view(
                        v_dram_unmerged,
                        make_tuple(make_xor_transform(
                                       make_tuple(length,
                                                  number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
                                   make_pass_through_transform(number<XorGroupSize>{})),
                        make_tuple(sequence<0, 1>{}, sequence<2>{}),
                        make_tuple(sequence<0, 1>{}, sequence<2>{}));

                    // merge back to [length, hdim_v]
                    return transform_tensor_view(
                        v_dram_permuted,
                        make_tuple(make_pass_through_transform(length),
                                   make_merge_transform_v3_division_mod(
                                       make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
                                                  number<XorGroupSize>{}))),
                        make_tuple(sequence<0>{}, sequence<1, 2>{}),
                        make_tuple(sequence<0>{}, sequence<1>{}));
                }
            };

            const auto v_dram = [&]() {
                {
                    return make_v_dram(v_ptr, kargs.seqlen_k);
                }
            }();

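            // Tile windows anchor each block's view: Q starts at its seqlen_q
            // tile row (i_m0), while the K and V windows start at {0, 0} and
            // are advanced along seqlen_k inside the pipeline.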
            auto q_dram_window = make_tile_window(
                q_dram,
                [&]() {
                    if constexpr(FmhaPipeline::kQLoadOnce)
                        return make_tuple(number<FmhaPipeline::kM0>{},
                                          number<FmhaPipeline::kSubQKHeaddim>{}); // assumed lengths
                    else
                        return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}); // assumed lengths
                }(),
                {i_m0, 0});

            auto k_dram_window = make_tile_window(
                k_dram,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), // assumed lengths
                {0, 0});

            auto v_dram_window = make_tile_window(
                v_dram,
                make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}), // assumed lengths
                {0, 0});

            const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
                constexpr auto bias_dram_window_lengths =
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{}); // assumed
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                {
                    const BiasDataType* bias_ptr =
                        reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
                        static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
                        batch_offset_bias;

                    const auto bias_dram = [&]() {
                        const auto bias_dram_naive =
                            make_naive_tensor_view<address_space_enum::global>(
                                bias_ptr,
                                make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                                make_tuple(kargs.stride_bias, 1),
                                number<FmhaPipeline::kAlignmentBias>{}, // assumed
                                number<1>{});

                        return pad_tensor_view(bias_dram_naive,
                                               bias_dram_window_lengths,
                                               sequence<kPadSeqLenQ, kPadSeqLenK>{});
                    }();

                    return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
                }
                else
                {
                    return make_null_tile_window(bias_dram_window_lengths);
                }
            }();

            // LSE output (log-sum-exp accumulator, one value per query row)
            auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
                constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
                if constexpr(kStoreLSE)
                {
                    LSEDataType* lse_ptr =
                        reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
                        static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
                        batch_offset_lse;

                    const auto lse_dram = [&] {
                        const auto lse_dram_naive = [&] {
                            {
                                return make_naive_tensor_view<address_space_enum::global>(
                                    lse_ptr,
                                    make_tuple(kargs.seqlen_q),
                                    make_tuple(1),
                                    number<1>{},
                                    number<1>{});
                            }
                        }();
                        return pad_tensor_view(
                            lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
                    }();

                    return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
                }
                else
                {
                    return make_null_tile_window(lse_dram_window_lengths);
                }
            }();

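            // The mask object only carries real left/right-window logic when
            // kHasMask is set; otherwise it degenerates to the trivial
            // [seqlen_q, seqlen_k] mask so the pipeline stays branch-free.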
            FmhaMask mask = [&]() {
                if constexpr(kHasMask)
                    return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
                        kargs.window_size_left,
                        kargs.window_size_right,
                        kargs.seqlen_q,
                        kargs.seqlen_k,
                        kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); // assumed flag
                else
                    return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
            }();

            // workaround: structured bindings cannot be captured by lambdas before C++20
            auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
                {
                    // slope loading, shared by the entire workgroup
                    // TODO: how to use s_read?
                    SaccDataType slope =
                        *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
                          i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                    slope *= ck_tile::log2e_v<>;
#endif
                    if constexpr(kHasMask)
                    {
                        return make_alibi_from_lr_mask<SaccDataType, true, 32>(
                            slope,
                            kargs.window_size_left,
                            kargs.window_size_right,
                            kargs.seqlen_q,
                            kargs.seqlen_k,
                            kargs.mask_type);
                    }
                    else
                    {
                        return Alibi<SaccDataType, true>{ // constructor line assumed
                            slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
                    }
                }
                else
                {
                    return EmptyPositionEncoding<SaccDataType>{};
                }
            }();

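            // For the prefill case the pipeline runs with two K and two V LDS
            // buffers (ping-pong double buffering), presumably so global loads
            // of the next tile can overlap the GEMMs on the current one;
            // otherwise a single shared allocation of GetSmemSize() bytes is
            // used.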
            auto o_acc_tile = [&]() {
                if constexpr(PrefillCase)
                {
                    // allocate double (ping-pong) LDS buffers
                    // TODO: add __restrict__ here to avoid aliasing
                    __shared__ char smem_ptrk0[FmhaPipeline::Policy::template GetSmemSizeK<
                        typename FmhaPipeline::Problem, true>()];
                    __shared__ char smem_ptrk1[FmhaPipeline::Policy::template GetSmemSizeK<
                        typename FmhaPipeline::Problem, true>()];
                    __shared__ char smem_ptrv0[FmhaPipeline::Policy::template GetSmemSizeV<
                        typename FmhaPipeline::Problem>()];
                    __shared__ char smem_ptrv1[FmhaPipeline::Policy::template GetSmemSizeV<
                        typename FmhaPipeline::Problem>()];

                    return FmhaPipeline{}(q_dram_window,
                                          k_dram_window,
                                          v_dram_window,
                                          bias_dram_window,
                                          lse_dram_window,
                                          mask,
                                          position_encoding,
                                          kargs.scale_s,
                                          smem_ptrk0,
                                          smem_ptrk1,
                                          smem_ptrv0,
                                          smem_ptrv1);
                }
                else
                {
                    __shared__ char smem_ptr[GetSmemSize()];
                    return FmhaPipeline{}(q_dram_window,
                                          k_dram_window,
                                          v_dram_window,
                                          bias_dram_window,
                                          lse_dram_window,
                                          mask,
                                          position_encoding,
                                          kargs.scale_s,
                                          smem_ptr);
                }
            }();

            // O DRAM tensor view and window (stores the o_acc tile via the epilogue)
            auto o_dram = [&] {
                const auto o_dram_naive = [&] {
                    {
                        return make_naive_tensor_view<address_space_enum::global>(
                            o_ptr,
                            make_tuple(kargs.seqlen_q, kargs.hdim_v),
                            make_tuple(kargs.stride_o, 1),
                            number<FmhaPipeline::kAlignmentO>{}, // assumed
                            number<1>{});
                    }
                }();

                return pad_tensor_view(
                    o_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{},
                               number<FmhaPipeline::kN1>{}), // assumed pad-tile lengths
                    sequence<kPadSeqLenQ, kPadHeadDimV>{});
            }();

            auto o_dram_window = make_tile_window(
                o_dram,
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}), // assumed
                {i_m0, i_n1});

            EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
        }
    }
};

} // namespace ck_tile
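
// ---------------------------------------------------------------------------
// Host-side usage sketch (editorial, illustrative only). The kernel type is
// instantiated from the template declared earlier in this header; MyPipeline,
// MyEpilogue, and the launch helpers named below are assumptions for this
// sketch, and the full MakeKargs argument list (pointers, strides, scales,
// mask/window parameters, dropout state) is abbreviated.
//
//   using Kernel = ck_tile::FmhaFwdKernel<MyPipeline, MyEpilogue>; // hypothetical instantiation
//
//   auto kargs  = Kernel::MakeKargs(q_ptr, k_ptr, v_ptr, /* ...remaining batch-mode args... */);
//   dim3 grids  = Kernel::GridSize(batch, nhead, seqlen_q, hdim_v);
//   dim3 blocks = Kernel::BlockSize();
//
//   // launch through ck_tile's generic launcher (exact helper signatures may
//   // differ across ck_tile versions):
//   ck_tile::launch_kernel(ck_tile::stream_config{stream},
//                          ck_tile::make_kernel<Kernel::kBlockPerCu>(
//                              Kernel{}, grids, blocks, 0, kargs));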