Composable Kernel: include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File

fmha_fwd_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 #define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
17 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
18 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
19 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
20 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
21 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
22 
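For orientation, the comment block above maps onto the following standalone scalar reference (a minimal sketch, not part of this header: row-major tensors, a single head, no mask, LSE, or dropout):

// Standalone reference sketch of the math in the comments above.
#include <cmath>
#include <vector>

void attention_reference(const float* Q, const float* K, const float* V, const float* Bias,
                         float* O, int seqlen_q, int seqlen_k, int hdim_q, int hdim_v, float scale)
{
    std::vector<float> s(seqlen_k);
    for(int i = 0; i < seqlen_q; ++i)
    {
        // S'' = Q @ K^T * scale + bias
        float row_max = -INFINITY;
        for(int j = 0; j < seqlen_k; ++j)
        {
            float acc = 0.f;
            for(int d = 0; d < hdim_q; ++d)
                acc += Q[i * hdim_q + d] * K[j * hdim_q + d];
            s[j]    = acc * scale + (Bias ? Bias[i * seqlen_k + j] : 0.f);
            row_max = std::fmax(row_max, s[j]);
        }
        // P = softmax(S'') with the usual max-subtraction for stability
        float row_sum = 0.f;
        for(int j = 0; j < seqlen_k; ++j)
        {
            s[j] = std::exp(s[j] - row_max);
            row_sum += s[j];
        }
        // O = P @ V
        for(int d = 0; d < hdim_v; ++d)
        {
            float acc = 0.f;
            for(int j = 0; j < seqlen_k; ++j)
                acc += s[j] * V[j * hdim_v + d];
            O[i * hdim_v + d] = acc / row_sum;
        }
    }
}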
23 namespace ck_tile {
24 
 25 template <typename FmhaPipeline_, typename EpiloguePipeline_>
 26 struct FmhaFwdKernel
 27 {
 28  using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
 29  using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
30  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
31 
32  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
33  static_assert(kBlockPerCu > 0);
34  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
 35 
 36  using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
 37  using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
 38  using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
 39  using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
 40  using RandValOutputDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
 41  using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
 42  using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
 43  using OaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::OaccDataType>;
 44  using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
 45 
 46  using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::BlockFmhaShape::VLayout>;
 47 
48  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
49  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
50  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
51  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
52  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
53  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
54  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
55  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
56  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
57  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
58  static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
59 
 60  using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
 61  using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
 62  static constexpr bool kHasMask = FmhaMask::IsMasking;
63 
64  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
65 
66  static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
67 #if defined(__gfx950__)
68  static constexpr bool kIsAvailable = true;
69 #else
70  static constexpr bool kIsAvailable = !kUseTrLoad;
71 #endif
72  static constexpr std::string_view kPipelineName = FmhaPipeline::name;
73 
74  // clang-format off
75  template <typename T1, typename T2 = T1> struct t2s;
76  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
77  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
78  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
79  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
80  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
81  template <> struct t2s<ck_tile::fp8_t, ck_tile::bf16_t> { static constexpr const char * name = "fp8bf16"; };
82  template <> struct t2s<ck_tile::fp8_t, ck_tile::fp32_t> { static constexpr const char * name = "fp8fp32"; };
83  // clang-format on
84 
85  CK_TILE_HOST static std::string GetName()
86  {
87  // sync with generate.py
88  // clang-format off
89  using bfs = typename FmhaPipeline::BlockFmhaShape;
90  using g0br = typename bfs::Gemm0BlockWarps;
91  using g1br = typename bfs::Gemm1BlockWarps;
92  using g0wt = typename bfs::Gemm0WarpTile;
93  using g1wt = typename bfs::Gemm1WarpTile;
94  #define _SS_ std::string
95  #define _TS_ std::to_string
96  auto pn = [&] () {
97  std::string n;
98  if (kPadSeqLenQ) n += "s";
99  if (kPadSeqLenK) n += "sk";
100  if (kPadHeadDimQ) n += "d";
101  if (kPadHeadDimV) n += "dv";
102  return n.empty() ? n : std::string("p") + n; }();
103  return
104  _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType, ODataType>::name) +
105  "_" + (kIsGroupMode ? "group" : "batch") + "_"
106  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
107  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
108  "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
109  "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
110  "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
111  "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
112  (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
113  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
114  (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
115  (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload");
116  #undef _SS_
117  #undef _TS_
118  // clang-format on
119  }
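For example, assuming a hypothetical fp16 batch-mode instantiation with kQKHeaddim = 128, block tile 128x128x32x128x32, 4x1x1 warps and 32x32x16 warp tiles in both GEMMs, default occupancy, a pipeline named "qr", row-major V, and every optional feature disabled, GetName() would yield a string of the form:

fmha_fwd_d128_fp16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_vr_npad_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant_ntrload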
120 
121  template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a
122  // template arg
123  struct FmhaFwdEmptyKargs
124  {
125  };
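A minimal sketch of why the index parameter matters: the batch/group karg types below inherit one (possibly empty) base per optional feature, C++ forbids inheriting the same class twice, and distinct empty bases can all be laid out at zero size via the empty-base optimization. Names here are illustrative, not from this header:

// Illustrative only: distinct EmptyBase<I> instantiations are distinct types,
// so one struct may inherit several of them without a duplicate-base error,
// while the empty-base optimization keeps them at zero size.
template <int I>
struct EmptyBase {};

struct WithFeature { float scale_p, scale_o; };

// struct Bad : EmptyBase<0>, EmptyBase<0> {};   // error: duplicate base class
struct Good : EmptyBase<0>, EmptyBase<1>, WithFeature {};

static_assert(sizeof(Good) == sizeof(WithFeature), "empty bases add no size on typical ABIs");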
126 
127  // kargs use aggregate initialization, so no constructor is provided
128  // use inheritance to minimize karg size
129  // users need to use the MakeKargs() function to create kargs.
130  struct FmhaFwdCommonKargs
131  {
132  const void* q_ptr;
133  const void* k_ptr;
134  const void* v_ptr;
135  void* o_ptr;
136 
137  ck_tile::index_t seqlen_q;
138  ck_tile::index_t seqlen_k;
139  ck_tile::index_t hdim_q;
140  ck_tile::index_t hdim_v;
141 
142  ck_tile::index_t num_head_q;
143  // for MQA/GQA, nhead_q and nhead_k can differ. This parameter is nhead_q / nhead_k;
144  // if it is larger than 1, we are in the MQA/GQA case
145  ck_tile::index_t nhead_ratio_qk;
146  float scale_s;
147 
148  ck_tile::index_t stride_q;
149  ck_tile::index_t stride_k;
150  ck_tile::index_t stride_v;
151  ck_tile::index_t stride_o;
152 
153  ck_tile::index_t nhead_stride_q;
154  ck_tile::index_t nhead_stride_k;
155  ck_tile::index_t nhead_stride_v;
156  ck_tile::index_t nhead_stride_o;
157  };
158 
159  struct FmhaFwdLogitsSoftCapKargs
160  {
162 
163  void init_logits_soft_cap(float logits_soft_cap_)
164  {
165  if(0 < logits_soft_cap_)
166  {
167  logits_soft_cap = logits_soft_cap_;
168  logits_soft_cap_rcp = 1.f / logits_soft_cap_;
169  }
170  else
171  {
172  logits_soft_cap = 0.f;
173  logits_soft_cap_rcp = 0.f;
174  }
175  }
176 
177  float logits_soft_cap;
178  float logits_soft_cap_rcp;
179  };
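The stored reciprocal avoids a per-element division when the cap is applied. A common formulation (and, to our reading, the one used by the logits-soft-cap attention variant) squashes the logits through tanh; a scalar sketch under that assumption:

// Scalar sketch of logits soft-capping with a precomputed reciprocal,
// assuming the usual cap * tanh(s / cap) formulation.
#include <cmath>

float apply_logits_soft_cap(float s, float logits_soft_cap, float logits_soft_cap_rcp)
{
    // multiply by the reciprocal instead of dividing by logits_soft_cap
    return logits_soft_cap * std::tanh(s * logits_soft_cap_rcp);
}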
180 
181  struct FmhaFwdCommonBiasKargs
182  {
183  const void* bias_ptr = nullptr;
184  ck_tile::index_t stride_bias = 0;
185  ck_tile::index_t nhead_stride_bias = 0;
186  };
187 
188  struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs
189  {
190  ck_tile::index_t batch_stride_bias = 0;
191  };
192 
193  struct FmhaFwdAlibiKargs
194  {
195  // alibi is batch*nhead*1; it is the same in both batch and group mode
196  const void* alibi_slope_ptr;
197  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
198  };
199 
200  struct FmhaFwdMaskKargs
201  {
202  // ck_tile::index_t window_size_left, window_size_right;
203  ck_tile::index_t window_size_left, window_size_right;
204  ck_tile::GenericAttentionMaskEnum mask_type;
205  };
206 
207  struct FmhaFwdFp8StaticQuantKargs
208  {
209  float scale_p;
210  float scale_o;
211  };
212 
213  struct FmhaFwdCommonLSEKargs
214  {
215  void* lse_ptr = nullptr;
216  ck_tile::index_t nhead_stride_lse = 0;
217  ck_tile::index_t batch_stride_lse = 0;
218  };
219 
221  {
222  template <typename T>
223  union ValueOrPointer
224  {
225  T val;
226  const T* ptr;
227  };
228 
228 
229  ValueOrPointer<uint64_t> drop_seed;
230  ValueOrPointer<uint64_t> drop_offset;
231  bool is_drop_seed_offset_from_host;
232  };
233 
234  struct FmhaFwdCommonDropoutKargs
235  {
236  void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
237  {
238  float p_undrop = 1.0 - p_drop;
239  p_undrop_in_uint8_t =
240  uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
241  rp_undrop = 1.0 / p_undrop;
242 
243  this->drop_seed.val = seed;
244  this->drop_offset.val = offset;
245  this->is_drop_seed_offset_from_host = true;
246  }
247 
248  void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
249  {
250  float p_undrop = 1.0 - p_drop;
251  p_undrop_in_uint8_t =
252  uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
253  rp_undrop = 1.0 / p_undrop;
254 
255  this->drop_seed.ptr = seed_ptr;
256  this->drop_offset.ptr = offset_ptr;
257  this->is_drop_seed_offset_from_host = false;
258  }
259 
260  float rp_undrop = 1;
261  uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
262  bool is_store_randval = false;
263  void* rand_val_ptr = nullptr;
264 
265  ck_tile::index_t stride_randval = 0;
266  ck_tile::index_t nhead_stride_randval = 0;
267  };
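rp_undrop is the usual inverted-dropout rescale factor 1/(1 - p_drop), and p_undrop_in_uint8_t quantizes the keep probability to a byte so it can be compared against per-element random bytes. A hedged scalar sketch of how values like these are typically consumed:

// Scalar sketch of inverted dropout driven by a per-element random byte,
// mirroring how rp_undrop / p_undrop_in_uint8_t are typically used.
#include <cstdint>

float dropout_one(float p, uint8_t rand_byte, uint8_t p_undrop_in_uint8_t, float rp_undrop)
{
    // keep the element iff its random byte is below the quantized keep-probability;
    // kept elements are rescaled by 1/(1 - p_drop) to preserve the expectation
    return (rand_byte <= p_undrop_in_uint8_t) ? p * rp_undrop : 0.f;
}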
268 
269  struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
270  {
271  ck_tile::index_t batch_stride_randval = 0;
272  };
273 
274  struct FmhaFwdSkipMinSeqlenQKargs
275  {
276  ck_tile::index_t min_seqlen_q = 0;
277  };
278 
279  struct FmhaFwdBatchModeKargs
280  : FmhaFwdCommonKargs,
281  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
282  FmhaFwdBatchModeBiasKargs,
283  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
284  FmhaFwdAlibiKargs,
285  FmhaFwdEmptyKargs<0>>>,
286  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
287  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
288  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
289  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
290  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
291  {
292  ck_tile::index_t batch_stride_q;
293  ck_tile::index_t batch_stride_k;
294  ck_tile::index_t batch_stride_v;
295  ck_tile::index_t batch_stride_o;
296 
297  // Optional cumulative sequence length pointers for batch mode
298  // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
299  const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
300  const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD
301  };
302 
303  struct FmhaFwdGroupModeKargs
304  : FmhaFwdCommonKargs,
305  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
306  FmhaFwdCommonBiasKargs,
307  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
308  FmhaFwdAlibiKargs,
309  FmhaFwdEmptyKargs<0>>>,
310  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
311  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
312  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
313  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
314  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
315  std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
316  {
317  const int32_t* seqstart_q_ptr;
318  const int32_t* seqstart_k_ptr;
319  const int32_t* seqlen_q_ptr;
320  const int32_t* seqlen_k_ptr;
321 
322  // Optional per-sequence and cumulative logical (excluding padding) sequence length arrays
323  const int32_t* cu_seqlen_q_ptr = nullptr;
324  const int32_t* cu_seqlen_k_ptr = nullptr;
325  };
326 
327  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
328 
329  struct BlockIndices
330  {
331  ck_tile::index_t batch_idx;
332  ck_tile::index_t qo_head_idx;
333  ck_tile::index_t kv_head_idx;
334  };
335 
336  template <bool Cond = !kIsGroupMode>
337  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
338  MakeKargsImpl(const void* q_ptr,
339  const void* k_ptr,
340  const void* v_ptr,
341  const void* bias_ptr,
342  void* rand_val_ptr,
343  void* lse_ptr,
344  void* o_ptr,
345  ck_tile::index_t seqlen_q,
346  ck_tile::index_t seqlen_k,
347  ck_tile::index_t hdim_q,
348  ck_tile::index_t hdim_v,
349  ck_tile::index_t num_head_q,
350  ck_tile::index_t nhead_ratio_qk,
351  float scale_s,
352  float scale_p,
353  float scale_o,
354  float logits_soft_cap,
355  ck_tile::index_t stride_q,
356  ck_tile::index_t stride_k,
357  ck_tile::index_t stride_v,
358  ck_tile::index_t stride_bias,
359  ck_tile::index_t stride_randval,
360  ck_tile::index_t stride_o,
361  ck_tile::index_t nhead_stride_q,
362  ck_tile::index_t nhead_stride_k,
363  ck_tile::index_t nhead_stride_v,
364  ck_tile::index_t nhead_stride_bias,
365  ck_tile::index_t nhead_stride_randval,
366  ck_tile::index_t nhead_stride_lse,
367  ck_tile::index_t nhead_stride_o,
368  ck_tile::index_t batch_stride_q,
369  ck_tile::index_t batch_stride_k,
370  ck_tile::index_t batch_stride_v,
371  ck_tile::index_t batch_stride_bias,
372  ck_tile::index_t batch_stride_randval,
373  ck_tile::index_t batch_stride_lse,
374  ck_tile::index_t batch_stride_o,
375  ck_tile::index_t window_size_left,
376  ck_tile::index_t window_size_right,
377  ck_tile::index_t mask_type,
378  float p_drop,
379  bool s_randval,
380  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
381  drop_seed_offset,
382  const void* cu_seqlen_q_ptr = nullptr,
383  const void* cu_seqlen_k_ptr = nullptr)
384  {
385  Kargs kargs{{q_ptr,
386  k_ptr,
387  v_ptr,
388  o_ptr,
389  seqlen_q,
390  seqlen_k,
391  hdim_q,
392  hdim_v,
393  num_head_q,
394  nhead_ratio_qk,
395 #if CK_TILE_FMHA_FWD_FAST_EXP2
396  static_cast<float>(scale_s * ck_tile::log2e_v<>),
397 #else
398  scale_s,
399 #endif
400  stride_q,
401  stride_k,
402  stride_v,
403  stride_o,
404  nhead_stride_q,
405  nhead_stride_k,
406  nhead_stride_v,
407  nhead_stride_o}, // args for common karg
408  {}, // placeholder for bias
409  {}, // placeholder for mask
410  {}, // placeholder for lse
411  {}, // placeholder for fp8_static_quant args
412  {}, // placeholder for dropout
413  {}, // placeholder for logits_soft_cap
414  batch_stride_q,
415  batch_stride_k,
416  batch_stride_v,
417  batch_stride_o};
418 
419  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
420  {
421  kargs.bias_ptr = bias_ptr;
422  kargs.stride_bias = stride_bias;
423  kargs.nhead_stride_bias = nhead_stride_bias;
424  kargs.batch_stride_bias = batch_stride_bias;
425  }
426  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
427  {
428  kargs.alibi_slope_ptr = bias_ptr;
429  kargs.alibi_slope_stride = stride_bias;
430  }
431  if constexpr(kHasMask)
432  {
433  kargs.window_size_left = window_size_left;
434  kargs.window_size_right = window_size_right;
435  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
436  }
437  if constexpr(kStoreLSE)
438  {
439  kargs.lse_ptr = lse_ptr;
440  kargs.nhead_stride_lse = nhead_stride_lse;
441  kargs.batch_stride_lse = batch_stride_lse;
442  }
443  if constexpr(kDoFp8StaticQuant)
444  {
445  kargs.scale_p = scale_p;
446  kargs.scale_o = scale_o;
447  }
448  if constexpr(kHasDropout)
449  {
450  if(drop_seed_offset.index() == 0) // seed & offset come from host
451  {
452  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
453  kargs.init_dropout(p_drop, seed, offset);
454  }
455  else // seed & offset come from device
456  {
457  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
458  kargs.init_dropout(p_drop,
459  reinterpret_cast<const uint64_t*>(seed_ptr),
460  reinterpret_cast<const uint64_t*>(offset_ptr));
461  }
462 
463  kargs.rand_val_ptr = rand_val_ptr;
464  kargs.stride_randval = stride_randval;
465  kargs.nhead_stride_randval = nhead_stride_randval;
466  kargs.batch_stride_randval = batch_stride_randval;
467  kargs.is_store_randval = s_randval;
468  }
469  if constexpr(kHasLogitsSoftCap)
470  {
471  kargs.init_logits_soft_cap(logits_soft_cap);
472  }
473 
474  kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
475  kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
476  return kargs;
477  }
478 
479  // std::variant<> can't take in a list initializer, overload for backward compatibility
480  template <bool Cond = !kIsGroupMode>
481  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
482  MakeKargs(const void* q_ptr,
483  const void* k_ptr,
484  const void* v_ptr,
485  const void* bias_ptr,
486  void* rand_val_ptr,
487  void* lse_ptr,
488  void* o_ptr,
489  ck_tile::index_t seqlen_q,
490  ck_tile::index_t seqlen_k,
491  ck_tile::index_t hdim_q,
492  ck_tile::index_t hdim_v,
493  ck_tile::index_t num_head_q,
494  ck_tile::index_t nhead_ratio_qk,
495  float scale_s,
496  float scale_p,
497  float scale_o,
498  float logits_soft_cap,
499  ck_tile::index_t stride_q,
500  ck_tile::index_t stride_k,
501  ck_tile::index_t stride_v,
502  ck_tile::index_t stride_bias,
503  ck_tile::index_t stride_randval,
504  ck_tile::index_t stride_o,
505  ck_tile::index_t nhead_stride_q,
506  ck_tile::index_t nhead_stride_k,
507  ck_tile::index_t nhead_stride_v,
508  ck_tile::index_t nhead_stride_bias,
509  ck_tile::index_t nhead_stride_randval,
510  ck_tile::index_t nhead_stride_lse,
511  ck_tile::index_t nhead_stride_o,
512  ck_tile::index_t batch_stride_q,
513  ck_tile::index_t batch_stride_k,
514  ck_tile::index_t batch_stride_v,
515  ck_tile::index_t batch_stride_bias,
516  ck_tile::index_t batch_stride_randval,
517  ck_tile::index_t batch_stride_lse,
518  ck_tile::index_t batch_stride_o,
519  ck_tile::index_t window_size_left,
520  ck_tile::index_t window_size_right,
521  ck_tile::index_t mask_type,
522  float p_drop,
523  bool s_randval,
524  const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
525  const void* cu_seqlen_q_ptr = nullptr,
526  const void* cu_seqlen_k_ptr = nullptr)
527  {
528  return MakeKargsImpl(
529  q_ptr,
530  k_ptr,
531  v_ptr,
532  bias_ptr,
533  rand_val_ptr,
534  lse_ptr,
535  o_ptr,
536  seqlen_q,
537  seqlen_k,
538  hdim_q,
539  hdim_v,
540  num_head_q,
541  nhead_ratio_qk,
542  scale_s,
543  scale_p,
544  scale_o,
545  logits_soft_cap,
546  stride_q,
547  stride_k,
548  stride_v,
549  stride_bias,
550  stride_randval,
551  stride_o,
552  nhead_stride_q,
553  nhead_stride_k,
554  nhead_stride_v,
555  nhead_stride_bias,
556  nhead_stride_randval,
557  nhead_stride_lse,
558  nhead_stride_o,
559  batch_stride_q,
560  batch_stride_k,
561  batch_stride_v,
562  batch_stride_bias,
563  batch_stride_randval,
564  batch_stride_lse,
565  batch_stride_o,
566  window_size_left,
567  window_size_right,
568  mask_type,
569  p_drop,
570  s_randval,
571  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
572  cu_seqlen_q_ptr,
573  cu_seqlen_k_ptr);
574  }
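The comment above is an overload-resolution issue: a braced list such as {seed, offset} cannot select a std::variant alternative when the parameter is std::variant<pair, pair>, so the tuple-taking overloads convert explicitly before forwarding to MakeKargsImpl. A minimal sketch of the failure mode:

// Minimal sketch of why the tuple overloads exist: a braced initializer cannot
// pick a std::variant alternative, so the commented call would not compile.
#include <cstdint>
#include <utility>
#include <variant>

using SeedOffset =
    std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>;

void take(SeedOffset) {}

void demo(uint64_t seed, uint64_t offset)
{
    // take({seed, offset});              // error: cannot deduce the alternative
    take(std::make_pair(seed, offset));   // OK: alternative chosen explicitly
}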
575 
576  // std::variant<> can't take in a list initializer, overload for backward compatibility
577  template <bool Cond = !kIsGroupMode>
578  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
579  MakeKargs(const void* q_ptr,
580  const void* k_ptr,
581  const void* v_ptr,
582  const void* bias_ptr,
583  void* rand_val_ptr,
584  void* lse_ptr,
585  void* o_ptr,
586  ck_tile::index_t seqlen_q,
587  ck_tile::index_t seqlen_k,
588  ck_tile::index_t hdim_q,
589  ck_tile::index_t hdim_v,
590  ck_tile::index_t num_head_q,
591  ck_tile::index_t nhead_ratio_qk,
592  float scale_s,
593  float scale_p,
594  float scale_o,
595  float logits_soft_cap,
596  ck_tile::index_t stride_q,
597  ck_tile::index_t stride_k,
598  ck_tile::index_t stride_v,
599  ck_tile::index_t stride_bias,
600  ck_tile::index_t stride_randval,
601  ck_tile::index_t stride_o,
602  ck_tile::index_t nhead_stride_q,
603  ck_tile::index_t nhead_stride_k,
604  ck_tile::index_t nhead_stride_v,
605  ck_tile::index_t nhead_stride_bias,
606  ck_tile::index_t nhead_stride_randval,
607  ck_tile::index_t nhead_stride_lse,
608  ck_tile::index_t nhead_stride_o,
609  ck_tile::index_t batch_stride_q,
610  ck_tile::index_t batch_stride_k,
611  ck_tile::index_t batch_stride_v,
612  ck_tile::index_t batch_stride_bias,
613  ck_tile::index_t batch_stride_randval,
614  ck_tile::index_t batch_stride_lse,
615  ck_tile::index_t batch_stride_o,
616  ck_tile::index_t window_size_left,
617  ck_tile::index_t window_size_right,
618  ck_tile::index_t mask_type,
619  float p_drop,
620  bool s_randval,
621  const std::tuple<const void*, const void*>& drop_seed_offset,
622  const void* cu_seqlen_q_ptr = nullptr,
623  const void* cu_seqlen_k_ptr = nullptr)
624  {
625  return MakeKargsImpl(
626  q_ptr,
627  k_ptr,
628  v_ptr,
629  bias_ptr,
630  rand_val_ptr,
631  lse_ptr,
632  o_ptr,
633  seqlen_q,
634  seqlen_k,
635  hdim_q,
636  hdim_v,
637  num_head_q,
638  nhead_ratio_qk,
639  scale_s,
640  scale_p,
641  scale_o,
642  logits_soft_cap,
643  stride_q,
644  stride_k,
645  stride_v,
646  stride_bias,
647  stride_randval,
648  stride_o,
649  nhead_stride_q,
650  nhead_stride_k,
651  nhead_stride_v,
652  nhead_stride_bias,
653  nhead_stride_randval,
654  nhead_stride_lse,
655  nhead_stride_o,
656  batch_stride_q,
657  batch_stride_k,
658  batch_stride_v,
659  batch_stride_bias,
660  batch_stride_randval,
661  batch_stride_lse,
662  batch_stride_o,
663  window_size_left,
664  window_size_right,
665  mask_type,
666  p_drop,
667  s_randval,
668  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
669  cu_seqlen_q_ptr,
670  cu_seqlen_k_ptr);
671  }
672 
673  template <bool Cond = kIsGroupMode>
674  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
675  MakeKargsImpl(const void* q_ptr,
676  const void* k_ptr,
677  const void* v_ptr,
678  const void* bias_ptr,
679  void* rand_val_ptr,
680  void* lse_ptr,
681  void* o_ptr,
682  const void* seqstart_q_ptr,
683  const void* seqstart_k_ptr,
684  const void* seqlen_q_ptr,
685  const void* seqlen_k_ptr,
686  ck_tile::index_t hdim_q,
687  ck_tile::index_t hdim_v,
688  ck_tile::index_t num_head_q,
689  ck_tile::index_t nhead_ratio_qk,
690  float scale_s,
691  float scale_p,
692  float scale_o,
693  float logits_soft_cap,
694  ck_tile::index_t stride_q,
695  ck_tile::index_t stride_k,
696  ck_tile::index_t stride_v,
697  ck_tile::index_t stride_bias,
698  ck_tile::index_t stride_randval,
699  ck_tile::index_t stride_o,
700  ck_tile::index_t nhead_stride_q,
701  ck_tile::index_t nhead_stride_k,
702  ck_tile::index_t nhead_stride_v,
703  ck_tile::index_t nhead_stride_bias,
704  ck_tile::index_t nhead_stride_randval,
705  ck_tile::index_t nhead_stride_lse,
706  ck_tile::index_t nhead_stride_o,
707  ck_tile::index_t window_size_left,
708  ck_tile::index_t window_size_right,
709  ck_tile::index_t mask_type,
710  ck_tile::index_t min_seqlen_q,
711  float p_drop,
712  bool s_randval,
713  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
714  drop_seed_offset,
715  const void* cu_seqlen_q_ptr = nullptr,
716  const void* cu_seqlen_k_ptr = nullptr)
717  {
718  Kargs kargs{{q_ptr,
719  k_ptr,
720  v_ptr,
721  o_ptr,
722  -1, // seqlen will be updated by another pointer
723  -1, //
724  hdim_q,
725  hdim_v,
726  num_head_q,
727  nhead_ratio_qk,
728 #if CK_TILE_FMHA_FWD_FAST_EXP2
729  static_cast<float>(scale_s * ck_tile::log2e_v<>),
730 #else
731  scale_s,
732 #endif
733  stride_q,
734  stride_k,
735  stride_v,
736  stride_o,
737  nhead_stride_q,
738  nhead_stride_k,
739  nhead_stride_v,
740  nhead_stride_o}, // args for common karg
741  {}, // placeholder for bias
742  {}, // placeholder for mask
743  {}, // placeholder for lse
744  {}, // placeholder for fp8_static_quant args
745  {}, // placeholder for dropout
746  {}, // placeholder for logits_soft_cap
747  {}, // placeholder for min_seqlen_q
748  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
749  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
750  reinterpret_cast<const int32_t*>(seqlen_q_ptr),
751  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
752 
753  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
754  {
755  kargs.bias_ptr = bias_ptr;
756  kargs.stride_bias = stride_bias;
757  kargs.nhead_stride_bias = nhead_stride_bias;
758  }
759  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
760  {
761  kargs.alibi_slope_ptr = bias_ptr;
762  kargs.alibi_slope_stride = stride_bias;
763  }
764  if constexpr(kHasMask)
765  {
766  kargs.window_size_left = window_size_left;
767  kargs.window_size_right = window_size_right;
768  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
769  }
770  if constexpr(kStoreLSE)
771  {
772  kargs.lse_ptr = lse_ptr;
773  kargs.nhead_stride_lse = nhead_stride_lse;
774  }
775  if constexpr(kDoFp8StaticQuant)
776  {
777  kargs.scale_p = scale_p;
778  kargs.scale_o = scale_o;
779  }
780  if constexpr(kHasDropout)
781  {
782  if(drop_seed_offset.index() == 0) // seed & offset come from host
783  {
784  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
785  kargs.init_dropout(p_drop, seed, offset);
786  }
787  else // seed & offset come from device
788  {
789  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
790  kargs.init_dropout(p_drop,
791  reinterpret_cast<const uint64_t*>(seed_ptr),
792  reinterpret_cast<const uint64_t*>(offset_ptr));
793  }
794 
795  kargs.rand_val_ptr = rand_val_ptr;
796  kargs.stride_randval = stride_randval;
797  kargs.nhead_stride_randval = nhead_stride_randval;
798  kargs.is_store_randval = s_randval;
799  }
800  if constexpr(kHasLogitsSoftCap)
801  {
802  kargs.init_logits_soft_cap(logits_soft_cap);
803  }
804  if constexpr(kSkipMinSeqlenQ)
805  {
806  kargs.min_seqlen_q = min_seqlen_q;
807  }
808 
809  kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
810  kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
811  return kargs;
812  }
813 
814  // std::variant<> can't take in a list initializer, overload for backward compatibility
815  template <bool Cond = kIsGroupMode>
816  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
817  MakeKargs(const void* q_ptr,
818  const void* k_ptr,
819  const void* v_ptr,
820  const void* bias_ptr,
821  void* rand_val_ptr,
822  void* lse_ptr,
823  void* o_ptr,
824  const void* seqstart_q_ptr,
825  const void* seqstart_k_ptr,
826  const void* seqlen_q_ptr,
827  const void* seqlen_k_ptr,
828  ck_tile::index_t hdim_q,
829  ck_tile::index_t hdim_v,
830  ck_tile::index_t num_head_q,
831  ck_tile::index_t nhead_ratio_qk,
832  float scale_s,
833  float scale_p,
834  float scale_o,
835  float logits_soft_cap,
836  ck_tile::index_t stride_q,
837  ck_tile::index_t stride_k,
838  ck_tile::index_t stride_v,
839  ck_tile::index_t stride_bias,
840  ck_tile::index_t stride_randval,
841  ck_tile::index_t stride_o,
842  ck_tile::index_t nhead_stride_q,
843  ck_tile::index_t nhead_stride_k,
844  ck_tile::index_t nhead_stride_v,
845  ck_tile::index_t nhead_stride_bias,
846  ck_tile::index_t nhead_stride_randval,
847  ck_tile::index_t nhead_stride_lse,
848  ck_tile::index_t nhead_stride_o,
849  ck_tile::index_t window_size_left,
850  ck_tile::index_t window_size_right,
851  ck_tile::index_t mask_type,
852  ck_tile::index_t min_seqlen_q,
853  float p_drop,
854  bool s_randval,
855  const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
856  const void* cu_seqlen_q_ptr = nullptr,
857  const void* cu_seqlen_k_ptr = nullptr)
858  {
859  return MakeKargsImpl(
860  q_ptr,
861  k_ptr,
862  v_ptr,
863  bias_ptr,
864  rand_val_ptr,
865  lse_ptr,
866  o_ptr,
867  seqstart_q_ptr,
868  seqstart_k_ptr,
869  seqlen_q_ptr,
870  seqlen_k_ptr,
871  hdim_q,
872  hdim_v,
873  num_head_q,
874  nhead_ratio_qk,
875  scale_s,
876  scale_p,
877  scale_o,
878  logits_soft_cap,
879  stride_q,
880  stride_k,
881  stride_v,
882  stride_bias,
883  stride_randval,
884  stride_o,
885  nhead_stride_q,
886  nhead_stride_k,
887  nhead_stride_v,
888  nhead_stride_bias,
889  nhead_stride_randval,
890  nhead_stride_lse,
891  nhead_stride_o,
892  window_size_left,
893  window_size_right,
894  mask_type,
895  min_seqlen_q,
896  p_drop,
897  s_randval,
898  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
899  cu_seqlen_q_ptr,
900  cu_seqlen_k_ptr);
901  }
902 
903  // std::variant<> can't take in a list initializer, overload for backward compatibility
904  template <bool Cond = kIsGroupMode>
905  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
906  MakeKargs(const void* q_ptr,
907  const void* k_ptr,
908  const void* v_ptr,
909  const void* bias_ptr,
910  void* rand_val_ptr,
911  void* lse_ptr,
912  void* o_ptr,
913  const void* seqstart_q_ptr,
914  const void* seqstart_k_ptr,
915  const void* seqlen_q_ptr,
916  const void* seqlen_k_ptr,
917  ck_tile::index_t hdim_q,
918  ck_tile::index_t hdim_v,
919  ck_tile::index_t num_head_q,
920  ck_tile::index_t nhead_ratio_qk,
921  float scale_s,
922  float scale_p,
923  float scale_o,
924  float logits_soft_cap,
925  ck_tile::index_t stride_q,
926  ck_tile::index_t stride_k,
927  ck_tile::index_t stride_v,
928  ck_tile::index_t stride_bias,
929  ck_tile::index_t stride_randval,
930  ck_tile::index_t stride_o,
931  ck_tile::index_t nhead_stride_q,
932  ck_tile::index_t nhead_stride_k,
933  ck_tile::index_t nhead_stride_v,
934  ck_tile::index_t nhead_stride_bias,
935  ck_tile::index_t nhead_stride_randval,
936  ck_tile::index_t nhead_stride_lse,
937  ck_tile::index_t nhead_stride_o,
938  ck_tile::index_t window_size_left,
939  ck_tile::index_t window_size_right,
940  ck_tile::index_t mask_type,
941  ck_tile::index_t min_seqlen_q,
942  float p_drop,
943  bool s_randval,
944  const std::tuple<const void*, const void*>& drop_seed_offset,
945  const void* cu_seqlen_q_ptr = nullptr,
946  const void* cu_seqlen_k_ptr = nullptr)
947  {
948  return MakeKargsImpl(
949  q_ptr,
950  k_ptr,
951  v_ptr,
952  bias_ptr,
953  rand_val_ptr,
954  lse_ptr,
955  o_ptr,
956  seqstart_q_ptr,
957  seqstart_k_ptr,
958  seqlen_q_ptr,
959  seqlen_k_ptr,
960  hdim_q,
961  hdim_v,
962  num_head_q,
963  nhead_ratio_qk,
964  scale_s,
965  scale_p,
966  scale_o,
967  logits_soft_cap,
968  stride_q,
969  stride_k,
970  stride_v,
971  stride_bias,
972  stride_randval,
973  stride_o,
974  nhead_stride_q,
975  nhead_stride_k,
976  nhead_stride_v,
977  nhead_stride_bias,
978  nhead_stride_randval,
979  nhead_stride_lse,
980  nhead_stride_o,
981  window_size_left,
982  window_size_right,
983  mask_type,
984  min_seqlen_q,
985  p_drop,
986  s_randval,
987  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
988  cu_seqlen_q_ptr,
989  cu_seqlen_k_ptr);
990  }
991 
992  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
993  ck_tile::index_t nhead_,
994  ck_tile::index_t seqlen_q_,
995  ck_tile::index_t hdim_v_,
996  bool has_padded_seqlen_k = false)
997  {
998  // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
999  if(has_padded_seqlen_k)
1000  {
1001  // TODO: this may need tuning
1002  return dim3(nhead_,
1003  batch_size_,
1004  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
1005  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
1006  }
1007  else
1008  {
1009  // TODO: this may need tuning
1010  return dim3(nhead_,
1011  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
1012  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
1013  batch_size_);
1014  }
1015  }
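A worked example with hypothetical sizes: batch 4, 8 heads, seqlen_q 1024, hdim_v 128 and tile sizes kM0 = kN1 = 128 give ceil(1024/128) * ceil(128/128) = 8 tiles per head, so the grid is dim3(8, 8, 4) in the default layout, or dim3(8, 4, 8) when seqlen_k is padded per batch:

// Worked example of the grid computation above, with hypothetical tile sizes
// kM0 = 128 and kN1 = 128 standing in for FmhaPipeline's values.
#include <cassert>

int main()
{
    const int kM0 = 128, kN1 = 128;
    const int batch = 4, nhead = 8, seqlen_q = 1024, hdim_v = 128;

    const int num_tiles = ((seqlen_q + kM0 - 1) / kM0) * ((hdim_v + kN1 - 1) / kN1);
    assert(num_tiles == 8);

    // default layout:             dim3(nhead, num_tiles, batch) == dim3(8, 8, 4)
    // has_padded_seqlen_k layout: dim3(nhead, batch, num_tiles) == dim3(8, 4, 8)
    return 0;
}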
1016 
1017  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
1018  {
1019  bool has_padded_seqlen_k = false;
1020 
1021  if constexpr(kIsGroupMode)
1022  has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
1023 
1024  if(has_padded_seqlen_k)
1025  {
1026  // const index_t num_tile_m0 = seqlen_q / kM0;
1027  const index_t num_tile_n1 =
1028  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1029 
1030  const index_t i_block = blockIdx.z;
1031  const index_t i_nhead = blockIdx.x;
1032  const index_t i_batch = blockIdx.y;
1033 
1034  const auto f = [](index_t dividend, index_t divisor) {
1035  index_t quotient = dividend / divisor;
1036  index_t modulus = dividend - quotient * divisor;
1037  return ck_tile::make_tuple(quotient, modulus);
1038  };
1039 
1040  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1041 
1042  if constexpr(kHasMask)
1043  {
1044  // assume that num_tile_n1 is always 1
1045  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1046  }
1047  else
1048  {
1049  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1050  }
1051  }
1052  else
1053  {
1054  // const index_t num_tile_m0 = seqlen_q / kM0;
1055  const index_t num_tile_n1 =
1056  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1057 
1058  const index_t i_block = blockIdx.y; // blockIdx.x
1059  const index_t i_nhead = blockIdx.x; // blockIdx.y
1060  const index_t i_batch = blockIdx.z;
1061 
1062  const auto f = [](index_t dividend, index_t divisor) {
1063  index_t quotient = dividend / divisor;
1064  index_t modulus = dividend - quotient * divisor;
1065  return ck_tile::make_tuple(quotient, modulus);
1066  };
1067 
1068  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1069 
1070  if constexpr(kHasMask)
1071  {
1072  // assume that num_tile_n1 is always 1
1073  return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1074  }
1075  else
1076  {
1077  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1078  }
1079  }
1080  }
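When a mask is active the m-tile index is reversed (gridDim - 1 - i_tile_m). With a causal mask, higher query tiles cover more keys and therefore more work; scheduling them first lets the short tiles fill in behind them, evening out the load across compute units. A sketch of the divmod-plus-reversal (assuming num_tile_n1 == 1 in the masked case, as the comments above do):

// Sketch of the tile-index decomposition above: split the flattened block id
// into (m, n) tiles, then reverse m when masking so the heaviest causal tiles
// are scheduled first.
#include <tuple>

std::tuple<int, int> tile_index(int i_block, int num_tile_n1, int grid_m_extent, bool has_mask)
{
    const int i_tile_m = i_block / num_tile_n1;
    const int i_tile_n = i_block - i_tile_m * num_tile_n1;
    return {has_mask ? grid_m_extent - 1 - i_tile_m : i_tile_m, i_tile_n};
}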
1081 
1082  CK_TILE_HOST static dim3 BlockSize()
1083  {
1084  if(is_wave32())
1085  {
1086  return dim3(kBlockSize / 2);
1087  }
1088  else
1089  {
1090  return dim3(kBlockSize);
1091  }
1092  }
1093 
1094  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
1095  {
1096  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
1097  }
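GridSize(), BlockSize(), and GetSmemSize() together describe the launch; the LDS requirement is the max of the pipeline's and epilogue's needs because both phases reuse the same statically allocated buffer (see operator() below). A hedged host-side sketch of querying these helpers (the dispatch itself is illustrative; real examples go through CK Tile's launch utilities):

// Hedged sketch: querying the kernel's static launch parameters. The LDS size
// is informational here, since operator() allocates it statically.
#include <hip/hip_runtime.h>

template <typename Kernel>
void query_launch_params(int batch, int nhead, int seqlen_q, int hdim_v)
{
    const dim3 grids  = Kernel::GridSize(batch, nhead, seqlen_q, hdim_v);
    const dim3 blocks = Kernel::BlockSize(); // kBlockSize threads, halved on wave32 devices
    const auto lds    = Kernel::GetSmemSize();
    (void)grids; (void)blocks; (void)lds;
    // a real launcher would now dispatch, e.g. via ck_tile's launch helpers
}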
1098 
1099  CK_TILE_DEVICE void operator()(Kargs kargs) const
1100  {
1101  if constexpr(kIsAvailable)
1102  run_(std::move(kargs));
1103  }
1104 
1105  CK_TILE_DEVICE void run_(Kargs kargs) const
1106  {
1107  if constexpr(kPipelineName != "qr_async_trload")
1108  {
1109  // allocate LDS
1110  __shared__ char smem_ptr[GetSmemSize()];
1111 
1112  // divide problem
1113  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1114 
1115  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
1116  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
1117 
1118  long_index_t batch_offset_q = 0;
1119  long_index_t batch_offset_k = 0;
1120  long_index_t batch_offset_v = 0;
1121  long_index_t batch_offset_bias = 0;
1122  long_index_t batch_offset_randval = 0;
1123  long_index_t batch_offset_lse = 0;
1124  long_index_t batch_offset_o = 0;
1125 
1126  if constexpr(kIsGroupMode)
1127  {
1128  // Use seqstart_q_ptr and seqstart_k_ptr for physical starts
1129  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1130  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1131 
1132  // DRAM base offsets use physical starts
1133  batch_offset_q = query_start * kargs.stride_q;
1134  batch_offset_k = key_start * kargs.stride_k;
1135  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1136  {
1137  batch_offset_v = key_start * kargs.stride_v;
1138  }
1139  else
1140  {
1141  batch_offset_v = key_start;
1142  }
1143  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1144  {
1145  batch_offset_bias = query_start * kargs.stride_bias;
1146  }
1147  if constexpr(kStoreLSE)
1148  {
1149  // LSE follows the physical layout to stay consistent with other tensors
1150  batch_offset_lse = query_start;
1151  }
1152  if constexpr(kHasDropout)
1153  {
1154  batch_offset_randval = query_start * kargs.stride_randval;
1155  }
1156  batch_offset_o = query_start * kargs.stride_o;
1157 
1158  // real logical lengths (exclude PAD)
1159  // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
1160  if(kargs.seqlen_q_ptr != nullptr)
1161  {
1162  kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
1163  }
1164  else if(kargs.cu_seqlen_q_ptr != nullptr)
1165  {
1166  kargs.seqlen_q =
1167  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1168  }
1169  else
1170  {
1171  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1172  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1173  }
1174 
1175  if constexpr(kSkipMinSeqlenQ)
1176  {
1177  if(kargs.seqlen_q <= kargs.min_seqlen_q)
1178  {
1179  return;
1180  }
1181  }
1182 
1183  // terminate unnecessary blocks earlier
1184  if(kargs.seqlen_q <= i_m0)
1185  {
1186  return;
1187  }
1188 
1189  if(kargs.seqlen_k_ptr != nullptr)
1190  {
1191  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1192  }
1193  else if(kargs.cu_seqlen_k_ptr != nullptr)
1194  {
1195  kargs.seqlen_k =
1196  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1197  }
1198  else
1199  {
1200  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1201  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1202  }
1203  }
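In group (varlen) mode all batches share one packed buffer, so per-batch lengths come from arrays: an explicit seqlen_q_ptr wins, then differences of the cumulative cu_seqlen_q_ptr, then differences of the physical seqstart_q_ptr offsets. A sketch of that priority with hypothetical names:

// Sketch of the per-batch seqlen resolution above (group mode). The three
// sources are tried in priority order; any pointer may be null.
#include <cstdint>

int32_t resolve_seqlen_q(int b,
                         const int32_t* seqlen_q_ptr,    // explicit per-batch lengths
                         const int32_t* cu_seqlen_q_ptr, // cumulative logical lengths
                         const int32_t* seqstart_q_ptr)  // physical start offsets
{
    if(seqlen_q_ptr != nullptr)
        return seqlen_q_ptr[b];
    if(cu_seqlen_q_ptr != nullptr)
        return cu_seqlen_q_ptr[b + 1] - cu_seqlen_q_ptr[b];
    return seqstart_q_ptr[b + 1] - seqstart_q_ptr[b];
}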
1204  else
1205  {
1206  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1207  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1208  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1209  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1210  {
1211  batch_offset_bias =
1212  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1213  }
1214  if constexpr(kStoreLSE)
1215  {
1216  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1217  }
1218  if constexpr(kHasDropout)
1219  {
1220  batch_offset_randval =
1221  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
1222  }
1223  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1224 
1225  // If cumulative seqlen pointers are provided, override per-batch effective lengths
1226  if(kargs.cu_seqlen_q_ptr != nullptr)
1227  {
1228  kargs.seqlen_q =
1229  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1230  }
1231  if(kargs.cu_seqlen_k_ptr != nullptr)
1232  {
1233  kargs.seqlen_k =
1234  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1235  }
1236  }
1237 
1238  // for simplicity, we apply the batch stride by offsetting the pointers directly
1239  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1240  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1241  batch_offset_q;
1242  const KDataType* k_ptr =
1243  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1244  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1245  batch_offset_k;
1246  const VDataType* v_ptr =
1247  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1248  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1249  batch_offset_v;
1250  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1251  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1252  batch_offset_o;
1253 
1254  // Q/K/V DRAM and DRAM window
1255  const auto q_dram = [&]() {
1256  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1257  q_ptr,
1258  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1259  make_tuple(kargs.stride_q, 1),
1261  number<1>{});
1262  if constexpr(FmhaPipeline::kQLoadOnce)
1263  {
1264  return pad_tensor_view(q_dram_naive,
1268  }
1269  else
1270  {
1271  return pad_tensor_view(
1272  q_dram_naive,
1275  }
1276  }();
1277  const auto k_dram = [&]() {
1278  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1279  k_ptr,
1280  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1281  make_tuple(kargs.stride_k, 1),
1283  number<1>{});
1284 
1285  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1286  return pad_tensor_view(
1287  k_dram_naive,
1290  }();
1291  const auto v_dram = [&]() {
1292  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1293  {
1294  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1295  v_ptr,
1296  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1297  make_tuple(kargs.stride_v, 1),
1299  number<1>{});
1300 
1301  const auto v_dram_transposed = transform_tensor_view(
1302  v_dram_naive,
1304  make_pass_through_transform(kargs.seqlen_k)),
1307 
1308  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1309  return pad_tensor_view(
1310  v_dram_transposed,
1313  }
1314  else
1315  {
1316  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1317  v_ptr,
1318  make_tuple(kargs.hdim_v, kargs.seqlen_k),
1319  make_tuple(kargs.stride_v, 1),
1321  number<1>{});
1322 
1323  constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
1324  return pad_tensor_view(
1325  v_dram_naive,
1328  }
1329  }();
1330 
1331  auto q_dram_window = make_tile_window(
1332  q_dram,
1333  [&]() {
1334  if constexpr(FmhaPipeline::kQLoadOnce)
1337  else
1339  }(),
1340  {i_m0, 0});
1341 
1342  auto k_dram_window = make_tile_window(
1343  k_dram,
1345  {0, 0});
1346 
1347  auto v_dram_window = make_tile_window(
1348  v_dram,
1350  {i_n1, 0});
1353  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1354  constexpr auto bias_dram_window_lengths =
1355  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
1356  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1357  {
1358  const BiasDataType* bias_ptr =
1359  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1360  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1361  batch_offset_bias;
1362 
1363  const auto bias_dram = [&]() {
1364  const auto bias_dram_naive =
1365  make_naive_tensor_view<address_space_enum::global>(
1366  bias_ptr,
1367  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1368  make_tuple(kargs.stride_bias, 1),
1370  number<1>{});
1371 
1372  return pad_tensor_view(bias_dram_naive,
1373  bias_dram_window_lengths,
1375  }();
1376 
1377  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1378  }
1379  else
1380  {
1381  return make_null_tile_window(bias_dram_window_lengths);
1382  }
1383  }();
1384 
1385  // lse
1386  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1387  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1388  if constexpr(kStoreLSE)
1389  {
1390  LSEDataType* lse_ptr =
1391  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1392  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
1393  batch_offset_lse;
1394 
1395  const auto lse_dram = [&]() {
1396  const auto lse_dram_naive =
1397  make_naive_tensor_view<address_space_enum::global>(
1398  lse_ptr,
1399  make_tuple(kargs.seqlen_q),
1400  make_tuple(1),
1401  number<1>{},
1402  number<1>{});
1403 
1404  return pad_tensor_view(
1405  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1406  }();
1407 
1408  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1409  }
1410  else
1411  {
1412  return make_null_tile_window(lse_dram_window_lengths);
1413  }
1414  }();
1415 
1416  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1417  if constexpr(kHasDropout)
1418  {
1419  return BlockDropout{i_batch_,
1420  i_nhead_,
1421  kargs.num_head_q,
1422  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1423  : *kargs.drop_seed.ptr,
1424  kargs.is_drop_seed_offset_from_host
1425  ? kargs.drop_offset.val
1426  : *kargs.drop_offset.ptr,
1427  kargs.rp_undrop,
1428  kargs.p_undrop_in_uint8_t,
1429  kargs.is_store_randval};
1430  }
1431  else
1432  {
1433  return NullBlockDropout{};
1434  };
1435  }();
1436 
1437  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1438  constexpr auto randval_dram_window_lengths =
1439  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
1440  if constexpr(kHasDropout)
1441  {
1442  RandValOutputDataType* rand_val_ptr =
1443  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1444  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1445  batch_offset_randval;
1446 
1447  const auto randval_dram = [&]() {
1448  const auto randval_dram_naive =
1449  make_naive_tensor_view<address_space_enum::global>(
1450  rand_val_ptr,
1451  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1452  make_tuple(kargs.stride_randval, 1),
1453  number<1>{},
1454  number<1>{});
1455 
1456  return pad_tensor_view(randval_dram_naive,
1457  randval_dram_window_lengths,
1459  }();
1460 
1461  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1462  }
1463  else
1464  {
1465  return make_null_tile_window(randval_dram_window_lengths);
1466  }
1467  }();
1468 
1469  FmhaMask mask = [&]() {
1470  if constexpr(kHasMask)
1471  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1472  kargs.window_size_left,
1473  kargs.window_size_right,
1474  kargs.seqlen_q,
1475  kargs.seqlen_k,
1476  kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
1477  else
1478  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1479  }();
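window_size_left/right encode the mask as a local-attention window. In the FlashAttention-style convention (which, to our reading, CK follows) a negative size means unbounded on that side, so (left = -1, right = 0) is a standard causal mask. A scalar membership test under that assumption, assuming seqlen_q == seqlen_k so the top-left and bottom-right anchorings coincide:

// Scalar sketch of a left/right-window mask test, assuming the usual
// convention that a negative window size means "unbounded on that side".
bool is_unmasked(int q, int k, int window_size_left, int window_size_right)
{
    const bool left_ok  = (window_size_left < 0) || (k >= q - window_size_left);
    const bool right_ok = (window_size_right < 0) || (k <= q + window_size_right);
    return left_ok && right_ok; // causal mask: (left = -1, right = 0) keeps k <= q
}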
1480 
1481  // workaround: lambdas cannot capture structured bindings (i_batch, i_nhead) before C++20
1482  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1483  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1484  {
1485  // data loading, shared by entire wg
1486  // TODO: how to use s_read?
1487  SaccDataType slope =
1488  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1489  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1490 #if CK_TILE_FMHA_FWD_FAST_EXP2
1491  slope *= ck_tile::log2e_v<>;
1492 #endif
1493  if constexpr(kHasMask)
1494  {
1495  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1496  kargs.window_size_left,
1497  kargs.window_size_right,
1498  kargs.seqlen_q,
1499  kargs.seqlen_k,
1500  kargs.mask_type);
1501  }
1502  else
1503  {
1504  return Alibi<SaccDataType, true>{
1505  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1506  }
1507  }
1508  else
1509  {
1510  return EmptyPositionEncoding<SaccDataType>{};
1511  }
1512  }();
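ALiBi adds a slope-scaled distance penalty to the logits instead of an explicit bias tensor; one slope per (batch, head) is fetched above, and under FAST_EXP2 it is pre-multiplied by log2(e) so the exp2-based softmax sees the correct value. A scalar sketch of the standard ALiBi term:

// Scalar sketch of the standard ALiBi penalty: logits receive slope * -|distance|,
// so more distant keys are progressively down-weighted.
float alibi_bias(float slope, int q_pos, int k_pos)
{
    const int distance = q_pos - k_pos; // >= 0 for causal attention
    return -slope * static_cast<float>(distance < 0 ? -distance : distance);
}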
1513 
1514  AttentionVariant variant;
1515  const auto variant_params = [&] {
1516  if constexpr(kHasLogitsSoftCap)
1517  {
1519  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1520  }
1521  else
1522  {
1523  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1524  }
1525  }();
1526 
1527  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1528 
1529  auto o_acc_tile = [&]() {
1530  if constexpr(kDoFp8StaticQuant)
1531  {
1532  auto o_acc_element_func = [&]() {
1533  if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
1534  return ck_tile::compose(ck_tile::saturates<ck_tile::fp8_t>{},
1535  ck_tile::scales{kargs.scale_o});
1536  else
1537  return ck_tile::scales{kargs.scale_o};
1538  }();
1539  return FmhaPipeline{}(q_dram_window,
1540  identity{}, // q_element_func
1541  k_dram_window,
1542  identity{}, // k_element_func
1543  v_dram_window,
1544  identity{}, // v_element_func
1545  bias_dram_window,
1546  identity{}, // bias_element_func
1547  randval_dram_window,
1548  lse_dram_window,
1549  identity{}, // lse_element_func
1550  identity{}, // s_acc_element_func
1551  scales{kargs.scale_p}, // p_compute_element_func
1552  o_acc_element_func, // o_acc_element_func
1553  mask,
1554  position_encoding,
1555  kargs.scale_s,
1556  variant,
1557  variant_params,
1558  block_indices,
1559  smem_ptr,
1560  dropout);
1561  }
1562  else
1563  {
1564  return FmhaPipeline{}(q_dram_window,
1565  k_dram_window,
1566  v_dram_window,
1567  bias_dram_window,
1568  randval_dram_window,
1569  lse_dram_window,
1570  mask,
1571  position_encoding,
1572  kargs.scale_s,
1573  variant,
1574  variant_params,
1575  block_indices,
1576  smem_ptr,
1577  dropout);
1578  }
1579  }();
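For fp8 static quantization the pipeline receives two extra element-wise functors: scales{scale_p} rescales the softmax probabilities before the P@V GEMM, and the output functor scales (and, for fp8 outputs, saturates) the accumulator before the store. A scalar sketch of the output side, assuming a generic clamp-to-range before the narrowing cast:

// Scalar sketch of the fp8 static-quant output path: scale the fp32
// accumulator, then saturate into the representable range before casting.
#include <algorithm>

float quantize_o(float o_acc, float scale_o, float fp8_max /* e.g. 448.f for e4m3 */)
{
    const float scaled = o_acc * scale_o;
    return std::min(std::max(scaled, -fp8_max), fp8_max); // a real kernel casts to fp8 here
}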
1580 
1581  // O DRAM and O DRAM window
1582  auto o_dram = [&]() {
1583  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1584  o_ptr,
1585  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1586  make_tuple(kargs.stride_o, 1),
1588  number<1>{});
1589 
1590  return pad_tensor_view(
1591  o_dram_naive,
1594  }();
1595 
1596  auto o_dram_window = make_tile_window(
1597  o_dram,
1599  {i_m0, i_n1});
1600 
1601  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1602  }
1603  else
1604  {
1605  // TODO: Refine the logic here.
1606  // In the decode case:
1607  // 1. we don't expect KV data to be reused by different ThreadGroups, so bypass the cache
1608  // 2. limit the LDS usage, as we want higher occupancy
1609  // In the prefill case:
1610  // 1. we expect KV data to be reused by different ThreadGroups, so use the cache
1611  // 2. use more LDS, as we want better memory-latency hiding
1612  // If SplitKV is off, we don't expect Q data to be reused by different ThreadGroups, so
1613  // bypass the cache
1614  constexpr bool PrefillCase = FmhaPipeline::kM0 >= 128;
1615  // divide problem
1616  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1617 
1618  const index_t i_m0 = i_tile_m * FmhaPipeline::kM0;
1619  const index_t i_n1 = i_tile_n * FmhaPipeline::kN1;
1620 
1621  long_index_t batch_offset_q = 0;
1622  long_index_t batch_offset_k = 0; // unused for paged-kvcache
1623  long_index_t batch_offset_v = 0; // unused for paged-kvcache
1624  long_index_t batch_offset_bias = 0;
1625  long_index_t batch_offset_lse = 0;
1626  long_index_t batch_offset_o = 0;
1627  // index_t kv_l2p_offset =
1628  // 0; // logical-to-physical offset of seqlen_k coordinate. only used for
1629  // paged-kvcache
1630 
1631  if constexpr(kIsGroupMode)
1632  {
1633  // get starting offset for each batch - use seqstart_q_ptr/seqstart_k_ptr for
1634  // physical starts
1635  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1636  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1637 
1638  batch_offset_q = query_start * kargs.stride_q;
1639  batch_offset_k = key_start * kargs.stride_k;
1640  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1641  {
1642  batch_offset_v = key_start * kargs.stride_v;
1643  }
1644  else
1645  {
1646  // col-major V: offset along seqlen dimension is scalar index
1647  batch_offset_v = key_start;
1648  }
1649  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1650  {
1651  batch_offset_bias = query_start * kargs.stride_bias;
1652  }
1653 
1654  // LSE layout is [nhead, total_seqlen] following the physical layout for Q/O
1655  batch_offset_lse = query_start;
1656  batch_offset_o = query_start * kargs.stride_o;
1657 
1658  // get real # queries & # keys under group mode
1659  if(kargs.seqlen_q_ptr != nullptr)
1660  {
1661  kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
1662  }
1663  else if(kargs.cu_seqlen_q_ptr != nullptr)
1664  {
1665  kargs.seqlen_q =
1666  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1667  }
1668  else
1669  {
1670  kargs.seqlen_q =
1671  kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
1672  }
1673 
1674  // # of required blocks is different in each group, terminate unnecessary blocks
1675  // earlier
1676  if(kargs.seqlen_q <= i_m0)
1677  {
1678  return;
1679  }
1680 
1681  if(kargs.seqlen_k_ptr != nullptr)
1682  {
1683  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1684  }
1685  else if(kargs.cu_seqlen_k_ptr != nullptr)
1686  {
1687  kargs.seqlen_k =
1688  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1689  }
1690  else
1691  {
1692  kargs.seqlen_k =
1693  kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
1694  }
1695  }
1696  else
1697  {
1698  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1699  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1700  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1701  if constexpr(kStoreLSE)
1702  {
1703  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1704  }
1705  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1706 
1707  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1708  {
1709  batch_offset_bias =
1710  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1711  }
1712 
1713  // If cumulative seqlen pointers are provided, override per-batch effective lengths
1714  if(kargs.cu_seqlen_q_ptr != nullptr)
1715  {
1716  kargs.seqlen_q =
1717  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1718  }
1719  if(kargs.cu_seqlen_k_ptr != nullptr)
1720  {
1721  kargs.seqlen_k =
1722  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1723  }
1724  }
1725 
1726  // for simplicity, we apply the batch stride by offsetting the pointers directly
1727  const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk;
1728 
1729  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1730  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1731  batch_offset_q;
1732  const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1733  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
1734  batch_offset_k;
1735  const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1736  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
1737  batch_offset_v;
1738 
1739  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1740  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1741  batch_offset_o;
1742 
1743  // Q/K/V DRAM and DRAM window
1744  const auto q_dram = [&] {
1745  const auto q_dram_naive = [&] {
1746  {
1747  return make_naive_tensor_view<address_space_enum::global,
1748  memory_operation_enum::set,
1750  q_ptr,
1751  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1752  make_tuple(kargs.stride_q, 1),
1754  number<1>{});
1755  }
1756  }();
1757 
1758  if constexpr(FmhaPipeline::kQLoadOnce)
1759  {
1760  const auto seqlen_q = kargs.seqlen_q;
1761  const auto q_dram_pad = pad_tensor_view(
1762  q_dram_naive,
1765 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1766  constexpr index_t LDSLayerSize = 256 / sizeof(QDataType);
1767  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1768 
1769  if constexpr(XorLengthFold > 1)
1770  {
1771  const auto q_dram_unmerged = transform_tensor_view(
1772  q_dram_pad,
1773  make_tuple(
1775  make_tuple(seqlen_q / XorLengthFold, XorLengthFold)),
1779 
1780  const auto q_dram_merged = transform_tensor_view(
1781  q_dram_unmerged,
1782  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1784  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1787 
1788  const auto q_dram_unmerged_xor = transform_tensor_view(
1789  q_dram_merged,
1790  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1796 
1797  const auto q_dram_permuted = transform_tensor_view(
1798  q_dram_unmerged_xor,
1799  make_tuple(
1801  make_tuple(seqlen_q / XorLengthFold,
1806 
1807  const auto q_dram_tmp = transform_tensor_view(
1808  q_dram_permuted,
1809  make_tuple(
1810  make_pass_through_transform(seqlen_q / XorLengthFold),
1813  number<FmhaPipeline::kQKHeaddim /
1814  FmhaPipeline::kAlignmentQ>{})),
1818 
1819  return transform_tensor_view(
1820  q_dram_tmp,
1821  make_tuple(
1823  make_tuple(seqlen_q / XorLengthFold, number<XorLengthFold>{})),
1829  }
1830  else
1831 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1832  {
1833  const auto q_dram_unmerged = transform_tensor_view(
1834  q_dram_pad,
1835  make_tuple(
1836  make_pass_through_transform(seqlen_q),
1842 
1843  const auto q_dram_permuted = transform_tensor_view(
1844  q_dram_unmerged,
1845  make_tuple(
1846  make_xor_transform(make_tuple(seqlen_q,
1847  number<FmhaPipeline::kQKHeaddim /
1848  FmhaPipeline::kAlignmentQ>{})),
1852 
1853  return transform_tensor_view(
1854  q_dram_permuted,
1855  make_tuple(
1856  make_pass_through_transform(seqlen_q),
1862  }
1863  }
1864  else
1865  {
1866  return pad_tensor_view(
1867  q_dram_naive,
1870  }
1871  }();
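The chain of transform_tensor_view calls above implements an XOR swizzle: the head-dimension vector index is XOR-ed with (a function of) the row index, so rows that would otherwise hit the same LDS bank land in different banks when the tile is later written to and read from LDS. A minimal index-level sketch of the idea:

// Minimal sketch of an XOR swizzle on (row, vec) coordinates: permute the
// vector index by the row so column accesses spread across LDS banks.
// num_vecs is assumed to be a power of two (e.g. kQKHeaddim / kAlignmentQ).
int swizzled_vec(int row, int vec, int num_vecs)
{
    return vec ^ (row % num_vecs);
}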
1872 
1873  const auto make_k_dram = [&](const KDataType* data, index_t height) {
1874  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1875  data, // will update this pointer if using paged-kvcache
1876  make_tuple(height, kargs.hdim_q),
1877  make_tuple(kargs.stride_k, 1),
1879  number<1>{});
1880 
1881  const auto k_dram_pad = pad_tensor_view(
1882  k_dram_naive,
1885 
1886  constexpr auto kDramTileK =
1887  FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0;
1888 
1889 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1890  constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
1891  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1892 
1893  if constexpr(XorLengthFold > 1)
1894  {
1895  const auto k_dram_unmerged = transform_tensor_view(
1896  k_dram_pad,
1898  make_tuple(height / XorLengthFold, XorLengthFold)),
1902 
1903  const auto k_dram_merged = transform_tensor_view(
1904  k_dram_unmerged,
1905  make_tuple(make_pass_through_transform(height / XorLengthFold),
1907  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1910 
1911  const auto k_dram_unmerged_xor = transform_tensor_view(
1912  k_dram_merged,
1913  make_tuple(make_pass_through_transform(height / XorLengthFold),
1919 
1920  const auto k_dram_permuted = transform_tensor_view(
1921  k_dram_unmerged_xor,
1922  make_tuple(
1924  make_tuple(height / XorLengthFold,
1929 
1930  const auto k_dram_tmp = transform_tensor_view(
1931  k_dram_permuted,
1932  make_tuple(
1933  make_pass_through_transform(height / XorLengthFold),
1936  number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
1940 
1941  return transform_tensor_view(
1942  k_dram_tmp,
1943  make_tuple(
1945  make_tuple(height / XorLengthFold, number<XorLengthFold>{})),
1951  }
1952  else
1953 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1954  {
1955  const auto k_dram_unmerged = transform_tensor_view(
1956  k_dram_pad,
1959  make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
1960  FmhaPipeline::kAlignmentK>{},
1961  number<kDramTileK / FmhaPipeline::kAlignmentK>{},
1965 
1966  const auto k_dram_permuted = transform_tensor_view(
1967  k_dram_unmerged,
1968  make_tuple(
1972  number<FmhaPipeline::kQKHeaddim / kDramTileK /
1973  FmhaPipeline::kAlignmentK>{}),
1977 
1978  return transform_tensor_view(
1979  k_dram_permuted,
1982  make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
1983  FmhaPipeline::kAlignmentK>{},
1984  number<kDramTileK / FmhaPipeline::kAlignmentK>{},
1988  }
1989  };
1990  const auto k_dram = [&]() {
1991  {
1992  return make_k_dram(k_ptr, kargs.seqlen_k);
1993  }
1994  }();
1995 
    const auto make_v_dram = [&](const VDataType* data, index_t length) {
        const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            data, // will update this pointer if using paged-kvcache
            make_tuple(length, kargs.hdim_v),
            make_tuple(kargs.stride_v, 1),
            number<FmhaPipeline::kAlignmentV>{},
            number<1>{});

        // TODO: Add kVHeadDim
        constexpr index_t XorGroupSize =
            FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});

        const auto v_dram_pad = pad_tensor_view(
            v_dram_naive,
            make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
            sequence<kPadSeqLenK, kPadHeadDimV>{});

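        // XorGroupSize is taken from the first (M) extent of the gemm-1 warp
        // tile: the group of V rows one warp consumes together is swizzled
        // against itself so the rows' LDS accesses can land in different banks.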
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
        constexpr index_t LDSLayerSize  = 256 / sizeof(VDataType);
        constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);

        if constexpr(XorLengthFold > 1)
        {
            const auto v_dram_unmerged = transform_tensor_view(
                v_dram_pad,
                make_tuple(make_unmerge_transform(
                               make_tuple(length / XorLengthFold, XorLengthFold)),
                           make_pass_through_transform(number<FmhaPipeline::kQKHeaddim>{})),
                make_tuple(sequence<0>{}, sequence<1>{}),
                make_tuple(sequence<0, 1>{}, sequence<2>{}));

            const auto v_dram_merged = transform_tensor_view(
                v_dram_unmerged,
                make_tuple(make_pass_through_transform(length / XorLengthFold),
                           make_merge_transform_v3_division_mod(make_tuple(
                               XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
                make_tuple(sequence<0>{}, sequence<1, 2>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            const auto v_dram_unmerged_xor = transform_tensor_view(
                v_dram_merged,
                make_tuple(
                    make_pass_through_transform(length / XorLengthFold),
                    make_unmerge_transform(make_tuple(
                        number<XorLengthFold * FmhaPipeline::kQKHeaddim / XorGroupSize>{},
                        number<XorGroupSize>{}))),
                make_tuple(sequence<0>{}, sequence<1>{}),
                make_tuple(sequence<0>{}, sequence<1, 2>{}));

            const auto v_dram_permuted = transform_tensor_view(
                v_dram_unmerged_xor,
                make_tuple(
                    make_xor_transform(make_tuple(length / XorLengthFold,
                                                  number<XorLengthFold *
                                                         FmhaPipeline::kQKHeaddim /
                                                         XorGroupSize>{})),
                    make_pass_through_transform(number<XorGroupSize>{})),
                make_tuple(sequence<0, 1>{}, sequence<2>{}),
                make_tuple(sequence<0, 1>{}, sequence<2>{}));

            const auto v_dram_tmp = transform_tensor_view(
                v_dram_permuted,
                make_tuple(make_pass_through_transform(length / XorLengthFold),
                           make_unmerge_transform(make_tuple(
                               number<XorLengthFold>{},
                               number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
                           make_pass_through_transform(number<XorGroupSize>{})),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));

            return transform_tensor_view(
                v_dram_tmp,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(length / XorLengthFold, number<XorLengthFold>{})),
                           make_merge_transform_v3_division_mod(make_tuple(
                               number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
                               number<XorGroupSize>{}))),
                make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));
        }
        else
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
        {
            const auto v_dram_unmerged = transform_tensor_view(
                v_dram_pad,
                make_tuple(make_unmerge_transform(
                               make_tuple(length / XorGroupSize,
                                          number<XorGroupSize>{})),
                           make_pass_through_transform(kargs.hdim_v)),
                make_tuple(sequence<0>{}, sequence<1>{}),
                make_tuple(sequence<0, 1>{}, sequence<2>{}));

            const auto v_dram_permuted = transform_tensor_view(
                v_dram_unmerged,
                make_tuple(make_pass_through_transform(length / XorGroupSize),
                           make_xor_transform(
                               make_tuple(kargs.hdim_v, number<XorGroupSize>{}))),
                make_tuple(sequence<0>{}, sequence<2, 1>{}),
                make_tuple(sequence<0>{}, sequence<2, 1>{}));

            return transform_tensor_view(
                v_dram_permuted,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(length / XorGroupSize,
                                          number<XorGroupSize>{})),
                           make_pass_through_transform(kargs.hdim_v)),
                make_tuple(sequence<0, 1>{}, sequence<2>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));
        }
    };

    const auto v_dram = [&]() {
        {
            return make_v_dram(v_ptr, kargs.seqlen_k);
        }
    }();

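    // Tile windows give this workgroup a kM0-row view of Q anchored at its
    // query-block row i_m0, plus K/V windows starting at {0, 0}; the pipeline
    // itself advances the K/V windows along seqlen_k.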
    auto q_dram_window = make_tile_window(
        q_dram,
        [&]() {
            if constexpr(FmhaPipeline::kQLoadOnce)
                return make_tuple(number<FmhaPipeline::kM0>{},
                                  number<FmhaPipeline::kQKHeaddim>{});
            else
                return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
        }(),
        {i_m0, 0});

    auto k_dram_window = make_tile_window(
        k_dram,
        make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
        {0, 0});

    auto v_dram_window = make_tile_window(
        v_dram,
        make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
        {0, 0});

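    // Elementwise bias, when enabled, is streamed through its own window over
    // the (seqlen_q, seqlen_k) score tile; otherwise a null window keeps the
    // pipeline signature uniform at no cost.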
    const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
        constexpr auto bias_dram_window_lengths =
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            const BiasDataType* bias_ptr =
                reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
                static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
                batch_offset_bias;

            const auto bias_dram = [&]() {
                const auto bias_dram_naive =
                    make_naive_tensor_view<address_space_enum::global>(
                        bias_ptr,
                        make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                        make_tuple(kargs.stride_bias, 1),
                        number<FmhaPipeline::kAlignmentBias>{},
                        number<1>{});

                return pad_tensor_view(bias_dram_naive,
                                       bias_dram_window_lengths,
                                       sequence<kPadSeqLenQ, kPadSeqLenK>{});
            }();

            return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
        }
        else
        {
            return make_null_tile_window(bias_dram_window_lengths);
        }
    }();

    // LSE: optionally store the per-row log-sum-exp of the softmax (used e.g.
    // by the backward pass to reconstruct the normalizer).
    auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
        constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
        if constexpr(kStoreLSE)
        {
            LSEDataType* lse_ptr =
                reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
                static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
                batch_offset_lse;

            const auto lse_dram = [&] {
                const auto lse_dram_naive = [&] {
                    {
                        return make_naive_tensor_view<address_space_enum::global>(
                            lse_ptr,
                            make_tuple(kargs.seqlen_q),
                            make_tuple(1),
                            number<1>{},
                            number<1>{});
                    }
                }();
                return pad_tensor_view(
                    lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
            }();

            return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
        }
        else
        {
            return make_null_tile_window(lse_dram_window_lengths);
        }
    }();

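    // The mask is built from left/right window sizes in the FlashAttention
    // convention (negative = unbounded), which covers causal and sliding-window
    // masking; kargs.mask_type selects how the window is anchored (top-left vs
    // bottom-right).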
    FmhaMask mask = [&]() {
        if constexpr(kHasMask)
            return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
                kargs.window_size_left,
                kargs.window_size_right,
                kargs.seqlen_q,
                kargs.seqlen_k,
                kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
        else
            return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
    }();

    // Workaround: init-capture i_batch/i_nhead, since structured bindings
    // cannot be captured by lambdas before C++20.
    auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
        if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            // slope loading is shared by the entire workgroup
            // TODO: how to use s_read?
            SaccDataType slope =
                *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
                  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
            slope *= ck_tile::log2e_v<>;
#endif
            if constexpr(kHasMask)
            {
                return make_alibi_from_lr_mask<SaccDataType, true, 32>(
                    slope,
                    kargs.window_size_left,
                    kargs.window_size_right,
                    kargs.seqlen_q,
                    kargs.seqlen_k,
                    kargs.mask_type);
            }
            else
            {
                return Alibi<SaccDataType, true>{
                    slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
            }
        }
        else
        {
            return EmptyPositionEncoding<SaccDataType>{};
        }
    }();

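    // Run the attention pipeline. The prefill path allocates two LDS buffers
    // each for K and V so the next tile can be staged while the current one is
    // consumed (double buffering); the generic path uses a single shared
    // allocation sized by GetSmemSize().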
    auto o_acc_tile = [&]() {
        if constexpr(PrefillCase)
        {
            // allocate double-buffered LDS for K and V
            // TODO: add __restrict__ here to avoid aliasing
            __shared__ char smem_ptrk0
                [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
                                                             true>()];
            __shared__ char smem_ptrk1
                [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
                                                             true>()];
            __shared__ char smem_ptrv0[FmhaPipeline::Policy::template GetSmemSizeV<
                typename FmhaPipeline::Problem>()];
            __shared__ char smem_ptrv1[FmhaPipeline::Policy::template GetSmemSizeV<
                typename FmhaPipeline::Problem>()];

            return FmhaPipeline{}(q_dram_window,
                                  k_dram_window,
                                  v_dram_window,
                                  bias_dram_window,
                                  lse_dram_window,
                                  mask,
                                  position_encoding,
                                  kargs.scale_s,
                                  smem_ptrk0,
                                  smem_ptrk1,
                                  smem_ptrv0,
                                  smem_ptrv1);
        }
        else
        {
            __shared__ char smem_ptr[GetSmemSize()];
            return FmhaPipeline{}(q_dram_window,
                                  k_dram_window,
                                  v_dram_window,
                                  bias_dram_window,
                                  lse_dram_window,
                                  mask,
                                  position_encoding,
                                  kargs.scale_s,
                                  smem_ptr);
        }
    }();

    // O DRAM tensor view and its output tile window
    auto o_dram = [&] {
        const auto o_dram_naive = [&] {
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    o_ptr,
                    make_tuple(kargs.seqlen_q, kargs.hdim_v),
                    make_tuple(kargs.stride_o, 1),
                    number<FmhaPipeline::kAlignmentO>{},
                    number<1>{});
            }
        }();

        return pad_tensor_view(
            o_dram_naive,
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
            sequence<kPadSeqLenQ, kPadHeadDimV>{});
    }();

    auto o_dram_window = make_tile_window(
        o_dram,
        make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
        {i_m0, i_n1});

    EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
    }
  }
};

} // namespace ck_tile
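
// Host-side usage sketch (illustrative only). The kernel/pipeline type names
// below are placeholders, the MakeKargs argument list is abbreviated, and
// ck_tile::launch_kernel / ck_tile::make_kernel are assumed host utilities;
// consult the MakeKargs / GridSize / BlockSize members above for the exact
// interfaces.
//
//   using Kernel = ck_tile::FmhaFwdKernel<MyFmhaPipeline, MyEpiloguePipeline>;
//   auto kargs   = Kernel::MakeKargs(q, k, v, /* bias/lse/o pointers, sizes,
//                                                strides, scales, ... */);
//   dim3 grids   = Kernel::GridSize(batch, nhead, seqlen_q, hdim_v);
//   dim3 blocks  = Kernel::BlockSize();
//   ck_tile::launch_kernel(ck_tile::stream_config{},
//                          ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));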