Composable Kernel: include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp

fmha_fwd_kernel.hpp
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
11 
12 #include <string>
13 #include <type_traits>
14 #include <utility>
15 #include <variant>
16 
17 #define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
18 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
19 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
20 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
21 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
22 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
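// Editorial illustration (not part of this header): a naive host-side sketch of
// the math above, assuming row-major float buffers with V stored as
// [seqlen_k, hdim_v]; bias/mask/dropout omitted. Names here are hypothetical.
#if 0
#include <cmath>
#include <vector>
inline void reference_fmha_fwd(const float* Q, const float* K, const float* V, float* O,
                               int seqlen_q, int seqlen_k, int hdim_q, int hdim_v, float scale)
{
    std::vector<float> p(seqlen_k);
    for(int i = 0; i < seqlen_q; ++i)
    {
        float m = -INFINITY, l = 0.f;
        for(int j = 0; j < seqlen_k; ++j)
        {
            float s = 0.f;
            for(int d = 0; d < hdim_q; ++d)
                s += Q[i * hdim_q + d] * K[j * hdim_q + d]; // S = Q @ K^T
            p[j] = s * scale;                               // S' = S * Scale
            m    = std::fmax(m, p[j]);
        }
        for(int j = 0; j < seqlen_k; ++j)
            l += (p[j] = std::exp(p[j] - m)); // numerically stable softmax numerator
        for(int d = 0; d < hdim_v; ++d)
        {
            float o = 0.f;
            for(int j = 0; j < seqlen_k; ++j)
                o += p[j] * V[j * hdim_v + d]; // O = P @ V
            O[i * hdim_v + d] = o / l;
        }
    }
}
#endif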
23 
24 namespace ck_tile {
25 
26 template <typename FmhaPipeline_, typename EpiloguePipeline_>
27 struct FmhaFwdKernel
28 {
29  using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
30  using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
31  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
32 
33  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
34  static_assert(kBlockPerCu > 0);
35  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
36 
37  using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
38  using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
39  using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
40  using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
41  using RandValOutputDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
42  using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
43  using PDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::PDataType>;
44  using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
45  using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
46  using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
47 
49 
50  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
51  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
52  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
53  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
54  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
55  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
56  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
57  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
58  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
59  static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
60  static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
61  static constexpr bool kHasSink = FmhaPipeline::kHasSink;
62 
63  using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
64  using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
65  static constexpr bool kHasMask = FmhaMask::IsMasking;
66 
67  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
68 
69  static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
70 #if defined(__gfx950__)
71  static constexpr bool kIsAvailable = true;
72 #else
73  static constexpr bool kIsAvailable = !kUseTrLoad;
74 #endif
75  static constexpr std::string_view kPipelineName = FmhaPipeline::name;
76 
77  template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a template
78  // arg
79  struct FmhaFwdEmptyKargs
80  {
81  };
82 
83  // kargs use aggregate initialization, so no constructor is provided;
84  // inheritance is used to minimize the karg size.
85  // users need to use the MakeKargs() function to create kargs.
86  struct FmhaFwdCommonKargs
87  {
88  const void* q_ptr;
89  const void* k_ptr;
90  const void* v_ptr;
91  void* o_ptr;
92 
93  ck_tile::index_t seqlen_q;
94  ck_tile::index_t seqlen_k;
95  ck_tile::index_t hdim_q;
96  ck_tile::index_t hdim_v;
97 
98  ck_tile::index_t num_head_q;
99  // for MQA/GQA, nhead could differ between Q and K. This parameter is nhead_q / nhead_k;
100  // if it is larger than 1, it indicates an MQA/GQA case
101  ck_tile::index_t nhead_ratio_qk;
102  float scale_s;
103 
104  ck_tile::index_t stride_q;
105  ck_tile::index_t stride_k;
106  ck_tile::index_t stride_v;
107  ck_tile::index_t stride_o;
108 
109  ck_tile::index_t nhead_stride_q;
110  ck_tile::index_t nhead_stride_k;
111  ck_tile::index_t nhead_stride_v;
112  ck_tile::index_t nhead_stride_o;
113  };
114 
115  struct FmhaFwdLogitsSoftCapKargs
116  {
118 
119  void init_logits_soft_cap(float logits_soft_cap_)
120  {
121  if(0 < logits_soft_cap_)
122  {
123  logits_soft_cap = logits_soft_cap_;
124  logits_soft_cap_rcp = 1.f / logits_soft_cap_;
125  }
126  else
127  {
128  logits_soft_cap = 0.f;
129  logits_soft_cap_rcp = 0.f;
130  }
131  }
132 
133  float logits_soft_cap = 0.f;
134  float logits_soft_cap_rcp = 0.f;
135  };
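// Editorial illustration: logits soft-capping is commonly applied as
// s' = cap * tanh(s / cap); precomputing logits_soft_cap_rcp above turns the
// division into a multiply, and a cap of 0 disables capping. A minimal sketch:
#if 0
inline float apply_soft_cap(float s, float cap, float cap_rcp) // cap_rcp == 1.f / cap
{
    return cap > 0.f ? cap * std::tanh(s * cap_rcp) : s;
}
#endif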
136 
137  struct FmhaFwdCommonBiasKargs
138  {
139  const void* bias_ptr = nullptr;
140  ck_tile::index_t stride_bias = 0;
141  ck_tile::index_t nhead_stride_bias = 0;
142  };
143 
144  struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs
145  {
146  ck_tile::index_t batch_stride_bias = 0;
147  };
148 
149  struct FmhaFwdAlibiKargs
150  {
151  // alibi slopes have shape batch*nhead*1; batch and group mode share the same layout
152  const void* alibi_slope_ptr;
153  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
154  };
155 
156  struct FmhaFwdMaskKargs
157  {
158  // ck_tile::index_t window_size_left, window_size_right;
159  ck_tile::index_t window_size_left, window_size_right, sink_size;
160  ck_tile::GenericAttentionMaskEnum mask_type;
161  };
162 
163  struct FmhaFwdCommonQScaleKargs
164  {
165  const void* q_descale_ptr = nullptr;
166  const void* k_descale_ptr = nullptr;
167  const void* v_descale_ptr = nullptr;
168  };
169 
170  struct FmhaFwdCommonLSEKargs
171  {
172  void* lse_ptr = nullptr;
173  ck_tile::index_t nhead_stride_lse = 0;
174  ck_tile::index_t batch_stride_lse = 0;
175  };
176 
177  struct FmhaFwdDropoutSeedOffsetKargs
178  {
179  template <typename T>
180  union ValueOrPointer
181  {
182  T val;
183  const T* ptr;
184  };
185 
186  ValueOrPointer<uint64_t> drop_seed;
187  ValueOrPointer<uint64_t> drop_offset;
188  bool is_drop_seed_offset_from_host = false;
189  };
190 
191  struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffsetKargs
192  {
193  void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
194  {
195  float p_undrop = 1.0 - p_drop;
198  rp_undrop = 1.0 / p_undrop;
199 
200  this->drop_seed.val = seed;
201  this->drop_offset.val = offset;
202  this->is_drop_seed_offset_from_host = true;
203  }
204 
205  void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
206  {
207  float p_undrop = 1.0 - p_drop;
210  rp_undrop = 1.0 / p_undrop;
211 
212  this->drop_seed.ptr = seed_ptr;
213  this->drop_offset.ptr = offset_ptr;
214  this->is_drop_seed_offset_from_host = false;
215  }
216 
217  float rp_undrop = 1;
218  uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
219  bool is_store_randval = false;
220  void* rand_val_ptr = nullptr;
221 
222  ck_tile::index_t stride_randval = 0;
223  ck_tile::index_t nhead_stride_randval = 0;
224  };
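// Editorial illustration: the two init_dropout() overloads select where the RNG
// state comes from; values here are hypothetical. After either call,
// rp_undrop == 1 / (1 - p_drop).
#if 0
FmhaFwdCommonDropoutKargs dk{};
dk.init_dropout(/*p_drop=*/0.1f, /*seed=*/42ULL, /*offset=*/0ULL); // state known on the host
const uint64_t* d_seed   = nullptr; // device buffers written by an earlier kernel
const uint64_t* d_offset = nullptr; // (hypothetical)
dk.init_dropout(0.1f, d_seed, d_offset); // state read from device memory at run time
#endif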
225 
226  struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
227  {
228  ck_tile::index_t batch_stride_randval = 0;
229  };
230 
231  struct FmhaFwdSkipMinSeqlenQKargs
232  {
233  ck_tile::index_t min_seqlen_q = 0;
234  };
235 
236  struct FmhaFwdBatchModeKargs
237  : FmhaFwdCommonKargs,
238  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
239  FmhaFwdBatchModeBiasKargs,
240  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
241  FmhaFwdAlibiKargs,
242  FmhaFwdEmptyKargs<0>>>,
243  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
244  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
245  std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
246  FmhaFwdCommonQScaleKargs,
247  FmhaFwdEmptyKargs<3>>,
248  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
249  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
250  {
251  ck_tile::index_t batch_stride_q;
252  ck_tile::index_t batch_stride_k;
253  ck_tile::index_t batch_stride_v;
254  ck_tile::index_t batch_stride_o;
255 
256  // Optional cumulative sequence length pointers for batch mode
257  // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
258  const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
259  const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD
260  };
261 
262  struct FmhaFwdGroupModeKargs
263  : FmhaFwdCommonKargs,
264  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
265  FmhaFwdCommonBiasKargs,
266  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
267  FmhaFwdAlibiKargs,
268  FmhaFwdEmptyKargs<0>>>,
269  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
270  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
271  std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
272  FmhaFwdCommonQScaleKargs,
273  FmhaFwdEmptyKargs<3>>,
274  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
275  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
276  std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
277  {
278  const int32_t* seqstart_q_ptr;
279  const int32_t* seqstart_k_ptr;
280  const int32_t* seqlen_q_ptr;
281  const int32_t* seqlen_k_ptr;
282 
283  // Optional per-sequence and cumulative logical (excluding padding) sequence length arrays
284  const int32_t* cu_seqlen_q_ptr = nullptr;
285  const int32_t* cu_seqlen_k_ptr = nullptr;
286  };
287 
288  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
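// Editorial illustration: each disabled feature contributes an empty base
// (FmhaFwdEmptyKargs<I>, with I avoiding duplicate bases), so the empty-base
// optimization keeps sizeof(Kargs) to just the enabled features, e.g.:
#if 0
static_assert(std::is_empty_v<FmhaFwdEmptyKargs<0>>);
// sizeof(Kargs) therefore grows only with the features compiled into the kernel.
#endif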
289 
290  struct BlockIndices
291  {
292  ck_tile::index_t batch_idx;
293  ck_tile::index_t qo_head_idx;
294  ck_tile::index_t kv_head_idx;
295  };
296 
297  template <bool Cond = !kIsGroupMode>
298  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
299  MakeKargsImpl(const void* q_ptr,
300  const void* k_ptr,
301  const void* v_ptr,
302  const void* bias_ptr,
303  const void* q_descale_ptr,
304  const void* k_descale_ptr,
305  const void* v_descale_ptr,
306  void* rand_val_ptr,
307  void* lse_ptr,
308  void* o_ptr,
309  ck_tile::index_t seqlen_q,
310  ck_tile::index_t seqlen_k,
311  ck_tile::index_t hdim_q,
312  ck_tile::index_t hdim_v,
313  ck_tile::index_t num_head_q,
314  ck_tile::index_t nhead_ratio_qk,
315  float scale_s,
316  float logits_soft_cap,
317  ck_tile::index_t stride_q,
318  ck_tile::index_t stride_k,
319  ck_tile::index_t stride_v,
320  ck_tile::index_t stride_bias,
321  ck_tile::index_t stride_randval,
322  ck_tile::index_t stride_o,
323  ck_tile::index_t nhead_stride_q,
324  ck_tile::index_t nhead_stride_k,
325  ck_tile::index_t nhead_stride_v,
326  ck_tile::index_t nhead_stride_bias,
327  ck_tile::index_t nhead_stride_randval,
328  ck_tile::index_t nhead_stride_lse,
329  ck_tile::index_t nhead_stride_o,
330  ck_tile::index_t batch_stride_q,
331  ck_tile::index_t batch_stride_k,
332  ck_tile::index_t batch_stride_v,
333  ck_tile::index_t batch_stride_bias,
334  ck_tile::index_t batch_stride_randval,
335  ck_tile::index_t batch_stride_lse,
336  ck_tile::index_t batch_stride_o,
337  ck_tile::index_t window_size_left,
338  ck_tile::index_t window_size_right,
339  ck_tile::index_t sink_size,
340  ck_tile::index_t mask_type,
341  float p_drop,
342  bool s_randval,
343  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
344  drop_seed_offset,
345  const void* cu_seqlen_q_ptr = nullptr,
346  const void* cu_seqlen_k_ptr = nullptr)
347  {
348  Kargs kargs{{q_ptr,
349  k_ptr,
350  v_ptr,
351  o_ptr,
352  seqlen_q,
353  seqlen_k,
354  hdim_q,
355  hdim_v,
356  num_head_q,
357  nhead_ratio_qk,
358 #if CK_TILE_FMHA_FWD_FAST_EXP2
359  static_cast<float>(scale_s * ck_tile::log2e_v<>),
360 #else
361  scale_s,
362 #endif
363  stride_q,
364  stride_k,
365  stride_v,
366  stride_o,
367  nhead_stride_q,
368  nhead_stride_k,
369  nhead_stride_v,
370  nhead_stride_o}, // args for common karg
371  {}, // placeholder for bias
372  {}, // placeholder for mask
373  {}, // placeholder for lse
374  {}, // placeholder for qscale
375  {}, // placeholder for dropout
376  {}, // placeholder for logits_soft_cap
377  batch_stride_q,
378  batch_stride_k,
379  batch_stride_v,
380  batch_stride_o};
381 
382  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
383  {
384  kargs.bias_ptr = bias_ptr;
385  kargs.stride_bias = stride_bias;
386  kargs.nhead_stride_bias = nhead_stride_bias;
387  kargs.batch_stride_bias = batch_stride_bias;
388  }
389  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
390  {
391  kargs.alibi_slope_ptr = bias_ptr;
392  kargs.alibi_slope_stride = stride_bias;
393  }
394  if constexpr(kHasMask)
395  {
396  kargs.window_size_left = window_size_left;
397  kargs.window_size_right = window_size_right;
398  kargs.sink_size = sink_size;
399  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
400  }
401  if constexpr(kStoreLSE)
402  {
403  kargs.lse_ptr = lse_ptr;
404  kargs.nhead_stride_lse = nhead_stride_lse;
405  kargs.batch_stride_lse = batch_stride_lse;
406  }
407  if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
408  {
409  kargs.q_descale_ptr = q_descale_ptr;
410  kargs.k_descale_ptr = k_descale_ptr;
411  kargs.v_descale_ptr = v_descale_ptr;
412  }
413  if constexpr(kHasDropout)
414  {
415  if(drop_seed_offset.index() == 0) // seed & offset come from host
416  {
417  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
418  kargs.init_dropout(p_drop, seed, offset);
419  }
420  else // seed & offset come from device
421  {
422  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
423  kargs.init_dropout(p_drop,
424  reinterpret_cast<const uint64_t*>(seed_ptr),
425  reinterpret_cast<const uint64_t*>(offset_ptr));
426  }
427 
428  kargs.rand_val_ptr = rand_val_ptr;
429  kargs.stride_randval = stride_randval;
430  kargs.nhead_stride_randval = nhead_stride_randval;
431  kargs.batch_stride_randval = batch_stride_randval;
432  kargs.is_store_randval = s_randval;
433  }
434  if constexpr(kHasLogitsSoftCap)
435  {
436  kargs.init_logits_soft_cap(logits_soft_cap);
437  }
438 
439  kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
440  kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
441  return kargs;
442  }
443 
444  // std::variant<> can't be built from a braced initializer list; this overload keeps backward compatibility
445  template <bool Cond = !kIsGroupMode>
446  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
447  MakeKargs(const void* q_ptr,
448  const void* k_ptr,
449  const void* v_ptr,
450  const void* bias_ptr,
451  const void* q_descale_ptr,
452  const void* k_descale_ptr,
453  const void* v_descale_ptr,
454  void* rand_val_ptr,
455  void* lse_ptr,
456  void* o_ptr,
457  ck_tile::index_t seqlen_q,
458  ck_tile::index_t seqlen_k,
459  ck_tile::index_t hdim_q,
460  ck_tile::index_t hdim_v,
461  ck_tile::index_t num_head_q,
462  ck_tile::index_t nhead_ratio_qk,
463  float scale_s,
464  float logits_soft_cap,
465  ck_tile::index_t stride_q,
466  ck_tile::index_t stride_k,
467  ck_tile::index_t stride_v,
468  ck_tile::index_t stride_bias,
469  ck_tile::index_t stride_randval,
470  ck_tile::index_t stride_o,
471  ck_tile::index_t nhead_stride_q,
472  ck_tile::index_t nhead_stride_k,
473  ck_tile::index_t nhead_stride_v,
474  ck_tile::index_t nhead_stride_bias,
475  ck_tile::index_t nhead_stride_randval,
476  ck_tile::index_t nhead_stride_lse,
477  ck_tile::index_t nhead_stride_o,
478  ck_tile::index_t batch_stride_q,
479  ck_tile::index_t batch_stride_k,
480  ck_tile::index_t batch_stride_v,
481  ck_tile::index_t batch_stride_bias,
482  ck_tile::index_t batch_stride_randval,
483  ck_tile::index_t batch_stride_lse,
484  ck_tile::index_t batch_stride_o,
485  ck_tile::index_t window_size_left,
486  ck_tile::index_t window_size_right,
487  ck_tile::index_t sink_size,
488  ck_tile::index_t mask_type,
489  float p_drop,
490  bool s_randval,
491  const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
492  const void* cu_seqlen_q_ptr = nullptr,
493  const void* cu_seqlen_k_ptr = nullptr)
494  {
495  return MakeKargsImpl(
496  q_ptr,
497  k_ptr,
498  v_ptr,
499  bias_ptr,
500  q_descale_ptr,
501  k_descale_ptr,
502  v_descale_ptr,
503  rand_val_ptr,
504  lse_ptr,
505  o_ptr,
506  seqlen_q,
507  seqlen_k,
508  hdim_q,
509  hdim_v,
510  num_head_q,
511  nhead_ratio_qk,
512  scale_s,
513  logits_soft_cap,
514  stride_q,
515  stride_k,
516  stride_v,
517  stride_bias,
518  stride_randval,
519  stride_o,
520  nhead_stride_q,
521  nhead_stride_k,
522  nhead_stride_v,
523  nhead_stride_bias,
524  nhead_stride_randval,
525  nhead_stride_lse,
526  nhead_stride_o,
527  batch_stride_q,
528  batch_stride_k,
529  batch_stride_v,
530  batch_stride_bias,
531  batch_stride_randval,
532  batch_stride_lse,
533  batch_stride_o,
534  window_size_left,
535  window_size_right,
536  sink_size,
537  mask_type,
538  p_drop,
539  s_randval,
540  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
541  cu_seqlen_q_ptr,
542  cu_seqlen_k_ptr);
543  }
544 
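// Editorial illustration: the overloads below merely wrap the (seed, offset)
// tuple into the std::variant alternative that MakeKargsImpl dispatches on:
#if 0
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>> v =
    std::make_pair(uint64_t{42}, uint64_t{0}); // v.index() == 0 -> host-provided seed/offset
#endif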
545  // std::variant<> can't be built from a braced initializer list; this overload keeps backward compatibility
546  template <bool Cond = !kIsGroupMode>
547  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
548  MakeKargs(const void* q_ptr,
549  const void* k_ptr,
550  const void* v_ptr,
551  const void* bias_ptr,
552  const void* q_descale_ptr,
553  const void* k_descale_ptr,
554  const void* v_descale_ptr,
555  void* rand_val_ptr,
556  void* lse_ptr,
557  void* o_ptr,
558  ck_tile::index_t seqlen_q,
559  ck_tile::index_t seqlen_k,
560  ck_tile::index_t hdim_q,
561  ck_tile::index_t hdim_v,
562  ck_tile::index_t num_head_q,
563  ck_tile::index_t nhead_ratio_qk,
564  float scale_s,
565  float logits_soft_cap,
566  ck_tile::index_t stride_q,
567  ck_tile::index_t stride_k,
568  ck_tile::index_t stride_v,
569  ck_tile::index_t stride_bias,
570  ck_tile::index_t stride_randval,
571  ck_tile::index_t stride_o,
572  ck_tile::index_t nhead_stride_q,
573  ck_tile::index_t nhead_stride_k,
574  ck_tile::index_t nhead_stride_v,
575  ck_tile::index_t nhead_stride_bias,
576  ck_tile::index_t nhead_stride_randval,
577  ck_tile::index_t nhead_stride_lse,
578  ck_tile::index_t nhead_stride_o,
579  ck_tile::index_t batch_stride_q,
580  ck_tile::index_t batch_stride_k,
581  ck_tile::index_t batch_stride_v,
582  ck_tile::index_t batch_stride_bias,
583  ck_tile::index_t batch_stride_randval,
584  ck_tile::index_t batch_stride_lse,
585  ck_tile::index_t batch_stride_o,
586  ck_tile::index_t window_size_left,
587  ck_tile::index_t window_size_right,
588  ck_tile::index_t sink_size,
589  ck_tile::index_t mask_type,
590  float p_drop,
591  bool s_randval,
592  const std::tuple<const void*, const void*>& drop_seed_offset,
593  const void* cu_seqlen_q_ptr = nullptr,
594  const void* cu_seqlen_k_ptr = nullptr)
595  {
596  return MakeKargsImpl(
597  q_ptr,
598  k_ptr,
599  v_ptr,
600  bias_ptr,
601  q_descale_ptr,
602  k_descale_ptr,
603  v_descale_ptr,
604  rand_val_ptr,
605  lse_ptr,
606  o_ptr,
607  seqlen_q,
608  seqlen_k,
609  hdim_q,
610  hdim_v,
611  num_head_q,
612  nhead_ratio_qk,
613  scale_s,
614  logits_soft_cap,
615  stride_q,
616  stride_k,
617  stride_v,
618  stride_bias,
619  stride_randval,
620  stride_o,
621  nhead_stride_q,
622  nhead_stride_k,
623  nhead_stride_v,
624  nhead_stride_bias,
625  nhead_stride_randval,
626  nhead_stride_lse,
627  nhead_stride_o,
628  batch_stride_q,
629  batch_stride_k,
630  batch_stride_v,
631  batch_stride_bias,
632  batch_stride_randval,
633  batch_stride_lse,
634  batch_stride_o,
635  window_size_left,
636  window_size_right,
637  sink_size,
638  mask_type,
639  p_drop,
640  s_randval,
641  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
642  cu_seqlen_q_ptr,
643  cu_seqlen_k_ptr);
644  }
645 
646  template <bool Cond = kIsGroupMode>
647  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
648  MakeKargsImpl(const void* q_ptr,
649  const void* k_ptr,
650  const void* v_ptr,
651  const void* bias_ptr,
652  const void* q_descale_ptr,
653  const void* k_descale_ptr,
654  const void* v_descale_ptr,
655  void* rand_val_ptr,
656  void* lse_ptr,
657  void* o_ptr,
658  const void* seqstart_q_ptr,
659  const void* seqstart_k_ptr,
660  const void* seqlen_q_ptr,
661  const void* seqlen_k_ptr,
662  ck_tile::index_t hdim_q,
663  ck_tile::index_t hdim_v,
664  ck_tile::index_t num_head_q,
665  ck_tile::index_t nhead_ratio_qk,
666  float scale_s,
667  float logits_soft_cap,
668  ck_tile::index_t stride_q,
669  ck_tile::index_t stride_k,
670  ck_tile::index_t stride_v,
671  ck_tile::index_t stride_bias,
672  ck_tile::index_t stride_randval,
673  ck_tile::index_t stride_o,
674  ck_tile::index_t nhead_stride_q,
675  ck_tile::index_t nhead_stride_k,
676  ck_tile::index_t nhead_stride_v,
677  ck_tile::index_t nhead_stride_bias,
678  ck_tile::index_t nhead_stride_randval,
679  ck_tile::index_t nhead_stride_lse,
680  ck_tile::index_t nhead_stride_o,
681  ck_tile::index_t window_size_left,
682  ck_tile::index_t window_size_right,
683  ck_tile::index_t sink_size,
684  ck_tile::index_t mask_type,
685  ck_tile::index_t min_seqlen_q,
686  float p_drop,
687  bool s_randval,
688  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
689  drop_seed_offset,
690  const void* cu_seqlen_q_ptr = nullptr,
691  const void* cu_seqlen_k_ptr = nullptr)
692  {
693  Kargs kargs{{q_ptr,
694  k_ptr,
695  v_ptr,
696  o_ptr,
697  -1, // seqlen will be updated by another pointer
698  -1, //
699  hdim_q,
700  hdim_v,
701  num_head_q,
702  nhead_ratio_qk,
703 #if CK_TILE_FMHA_FWD_FAST_EXP2
704  static_cast<float>(scale_s * ck_tile::log2e_v<>),
705 #else
706  scale_s,
707 #endif
708  stride_q,
709  stride_k,
710  stride_v,
711  stride_o,
712  nhead_stride_q,
713  nhead_stride_k,
714  nhead_stride_v,
715  nhead_stride_o}, // args for common karg
716  {}, // placeholder for bias
717  {}, // placeholder for mask
718  {}, // placeholder for lse
719  {}, // placeholder for qscale
720  {}, // placeholder for dropout
721  {}, // placeholder for logits_soft_cap
722  {}, // placeholder for min_seqlen_q
723  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
724  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
725  reinterpret_cast<const int32_t*>(seqlen_q_ptr),
726  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
727 
728  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
729  {
730  kargs.bias_ptr = bias_ptr;
731  kargs.stride_bias = stride_bias;
732  kargs.nhead_stride_bias = nhead_stride_bias;
733  }
734  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
735  {
736  kargs.alibi_slope_ptr = bias_ptr;
737  kargs.alibi_slope_stride = stride_bias;
738  }
739  if constexpr(kHasMask)
740  {
741  kargs.window_size_left = window_size_left;
742  kargs.window_size_right = window_size_right;
743  kargs.sink_size = sink_size;
744  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
745  }
746  if constexpr(kStoreLSE)
747  {
748  kargs.lse_ptr = lse_ptr;
749  kargs.nhead_stride_lse = nhead_stride_lse;
750  }
751  if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
752  {
753  kargs.q_descale_ptr = q_descale_ptr;
754  kargs.k_descale_ptr = k_descale_ptr;
755  kargs.v_descale_ptr = v_descale_ptr;
756  }
757  if constexpr(kHasDropout)
758  {
759  if(drop_seed_offset.index() == 0) // seed & offset come from host
760  {
761  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
762  kargs.init_dropout(p_drop, seed, offset);
763  }
764  else // seed & offset come from device
765  {
766  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
767  kargs.init_dropout(p_drop,
768  reinterpret_cast<const uint64_t*>(seed_ptr),
769  reinterpret_cast<const uint64_t*>(offset_ptr));
770  }
771 
772  kargs.rand_val_ptr = rand_val_ptr;
773  kargs.stride_randval = stride_randval;
774  kargs.nhead_stride_randval = nhead_stride_randval;
775  kargs.is_store_randval = s_randval;
776  }
777  if constexpr(kHasLogitsSoftCap)
778  {
779  kargs.init_logits_soft_cap(logits_soft_cap);
780  }
781  if constexpr(kSkipMinSeqlenQ)
782  {
783  kargs.min_seqlen_q = min_seqlen_q;
784  }
785 
786  kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
787  kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
788  return kargs;
789  }
790 
791  // std::variant<> can't be built from a braced initializer list; this overload keeps backward compatibility
792  template <bool Cond = kIsGroupMode>
793  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
794  MakeKargs(const void* q_ptr,
795  const void* k_ptr,
796  const void* v_ptr,
797  const void* bias_ptr,
798  const void* q_descale_ptr,
799  const void* k_descale_ptr,
800  const void* v_descale_ptr,
801  void* rand_val_ptr,
802  void* lse_ptr,
803  void* o_ptr,
804  const void* seqstart_q_ptr,
805  const void* seqstart_k_ptr,
806  const void* seqlen_q_ptr,
807  const void* seqlen_k_ptr,
808  ck_tile::index_t hdim_q,
809  ck_tile::index_t hdim_v,
810  ck_tile::index_t num_head_q,
811  ck_tile::index_t nhead_ratio_qk,
812  float scale_s,
813  float logits_soft_cap,
814  ck_tile::index_t stride_q,
815  ck_tile::index_t stride_k,
816  ck_tile::index_t stride_v,
817  ck_tile::index_t stride_bias,
818  ck_tile::index_t stride_randval,
819  ck_tile::index_t stride_o,
820  ck_tile::index_t nhead_stride_q,
821  ck_tile::index_t nhead_stride_k,
822  ck_tile::index_t nhead_stride_v,
823  ck_tile::index_t nhead_stride_bias,
824  ck_tile::index_t nhead_stride_randval,
825  ck_tile::index_t nhead_stride_lse,
826  ck_tile::index_t nhead_stride_o,
827  ck_tile::index_t window_size_left,
828  ck_tile::index_t window_size_right,
829  ck_tile::index_t sink_size,
830  ck_tile::index_t mask_type,
831  ck_tile::index_t min_seqlen_q,
832  float p_drop,
833  bool s_randval,
834  const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
835  const void* cu_seqlen_q_ptr = nullptr,
836  const void* cu_seqlen_k_ptr = nullptr)
837  {
838  return MakeKargsImpl(
839  q_ptr,
840  k_ptr,
841  v_ptr,
842  bias_ptr,
843  q_descale_ptr,
844  k_descale_ptr,
845  v_descale_ptr,
846  rand_val_ptr,
847  lse_ptr,
848  o_ptr,
849  seqstart_q_ptr,
850  seqstart_k_ptr,
851  seqlen_q_ptr,
852  seqlen_k_ptr,
853  hdim_q,
854  hdim_v,
855  num_head_q,
856  nhead_ratio_qk,
857  scale_s,
858  logits_soft_cap,
859  stride_q,
860  stride_k,
861  stride_v,
862  stride_bias,
863  stride_randval,
864  stride_o,
865  nhead_stride_q,
866  nhead_stride_k,
867  nhead_stride_v,
868  nhead_stride_bias,
869  nhead_stride_randval,
870  nhead_stride_lse,
871  nhead_stride_o,
872  window_size_left,
873  window_size_right,
874  sink_size,
875  mask_type,
876  min_seqlen_q,
877  p_drop,
878  s_randval,
879  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
880  cu_seqlen_q_ptr,
881  cu_seqlen_k_ptr);
882  }
883 
884  // std::variant<> can't be built from a braced initializer list; this overload keeps backward compatibility
885  template <bool Cond = kIsGroupMode>
886  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
887  MakeKargs(const void* q_ptr,
888  const void* k_ptr,
889  const void* v_ptr,
890  const void* bias_ptr,
891  const void* q_descale_ptr,
892  const void* k_descale_ptr,
893  const void* v_descale_ptr,
894  void* rand_val_ptr,
895  void* lse_ptr,
896  void* o_ptr,
897  const void* seqstart_q_ptr,
898  const void* seqstart_k_ptr,
899  const void* seqlen_q_ptr,
900  const void* seqlen_k_ptr,
901  ck_tile::index_t hdim_q,
902  ck_tile::index_t hdim_v,
903  ck_tile::index_t num_head_q,
904  ck_tile::index_t nhead_ratio_qk,
905  float scale_s,
906  float logits_soft_cap,
907  ck_tile::index_t stride_q,
908  ck_tile::index_t stride_k,
909  ck_tile::index_t stride_v,
910  ck_tile::index_t stride_bias,
911  ck_tile::index_t stride_randval,
912  ck_tile::index_t stride_o,
913  ck_tile::index_t nhead_stride_q,
914  ck_tile::index_t nhead_stride_k,
915  ck_tile::index_t nhead_stride_v,
916  ck_tile::index_t nhead_stride_bias,
917  ck_tile::index_t nhead_stride_randval,
918  ck_tile::index_t nhead_stride_lse,
919  ck_tile::index_t nhead_stride_o,
920  ck_tile::index_t window_size_left,
921  ck_tile::index_t window_size_right,
922  ck_tile::index_t sink_size,
923  ck_tile::index_t mask_type,
924  ck_tile::index_t min_seqlen_q,
925  float p_drop,
926  bool s_randval,
927  const std::tuple<const void*, const void*>& drop_seed_offset,
928  const void* cu_seqlen_q_ptr = nullptr,
929  const void* cu_seqlen_k_ptr = nullptr)
930  {
931  return MakeKargsImpl(
932  q_ptr,
933  k_ptr,
934  v_ptr,
935  bias_ptr,
936  q_descale_ptr,
937  k_descale_ptr,
938  v_descale_ptr,
939  rand_val_ptr,
940  lse_ptr,
941  o_ptr,
942  seqstart_q_ptr,
943  seqstart_k_ptr,
944  seqlen_q_ptr,
945  seqlen_k_ptr,
946  hdim_q,
947  hdim_v,
948  num_head_q,
949  nhead_ratio_qk,
950  scale_s,
951  logits_soft_cap,
952  stride_q,
953  stride_k,
954  stride_v,
955  stride_bias,
956  stride_randval,
957  stride_o,
958  nhead_stride_q,
959  nhead_stride_k,
960  nhead_stride_v,
961  nhead_stride_bias,
962  nhead_stride_randval,
963  nhead_stride_lse,
964  nhead_stride_o,
965  window_size_left,
966  window_size_right,
967  sink_size,
968  mask_type,
969  min_seqlen_q,
970  p_drop,
971  s_randval,
972  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
973  cu_seqlen_q_ptr,
974  cu_seqlen_k_ptr);
975  }
976 
977  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
978  ck_tile::index_t nhead_,
979  ck_tile::index_t seqlen_q_,
980  ck_tile::index_t hdim_v_,
981  bool has_padded_seqlen_k = false)
982  {
983  // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
984  if(has_padded_seqlen_k)
985  {
986  // TODO: this may need tuning
987  return dim3(nhead_,
988  batch_size_,
989  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
990  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
991  }
992  else
993  {
994  // TODO: this may need tuning
995  return dim3(nhead_,
996  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
997  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
998  batch_size_);
999  }
1000  }
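// Editorial illustration: assuming hypothetical tile sizes kM0 == 128 and
// kN1 == 128, a problem with batch 2, 8 heads, seqlen_q 1024 and hdim_v 128
// occupies ceil(1024/128) * ceil(128/128) == 8 tile slots per (batch, head):
#if 0
const dim3 grid = GridSize(2, 8, 1024, 128); // dim3(8, 8, 2) when !has_padded_seqlen_k
#endif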
1001 
1002  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
1003  {
1004  bool has_padded_seqlen_k = false;
1005 
1006  if constexpr(kIsGroupMode)
1007  has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
1008 
1009  if(has_padded_seqlen_k)
1010  {
1011  // const index_t num_tile_m0 = seqlen_q / kM0;
1012  const index_t num_tile_n1 =
1013  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1014 
1015  const index_t i_block = blockIdx.z;
1016  const index_t i_nhead = blockIdx.x;
1017  const index_t i_batch = blockIdx.y;
1018 
1019  const auto f = [](index_t dividend, index_t divisor) {
1020  index_t quotient = dividend / divisor;
1021  index_t modulus = dividend - quotient * divisor;
1022  return ck_tile::make_tuple(quotient, modulus);
1023  };
1024 
1025  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1026 
1027  if constexpr(kHasMask)
1028  {
1029  // assume that num_tile_n1 is always 1
1030  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1031  }
1032  else
1033  {
1034  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1035  }
1036  }
1037  else
1038  {
1039  // const index_t num_tile_m0 = seqlen_q / kM0;
1040  const index_t num_tile_n1 =
1041  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1042 
1043  const index_t i_block = blockIdx.y; // blockIdx.x
1044  const index_t i_nhead = blockIdx.x; // blockIdx.y
1045  const index_t i_batch = blockIdx.z;
1046 
1047  const auto f = [](index_t dividend, index_t divisor) {
1048  index_t quotient = dividend / divisor;
1049  index_t modulus = dividend - quotient * divisor;
1050  return ck_tile::make_tuple(quotient, modulus);
1051  };
1052 
1053  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1054 
1055  if constexpr(kHasMask)
1056  {
1057  // assume that num_tile_n1 is always 1
1058  return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1059  }
1060  else
1061  {
1062  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1063  }
1064  }
1065  }
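// Editorial illustration: f() above is a plain divmod decoding i_block into
// (i_tile_m, i_tile_n) in row-major order, e.g. i_block = 5 with num_tile_n1 = 1
// yields (5, 0). With kHasMask the m-tile order is reversed
// (gridDim - 1 - i_tile_m), likely so the longer-running tiles of a causal
// mask are scheduled first for better load balance across CUs.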
1066 
1067  CK_TILE_HOST static dim3 BlockSize()
1068  {
1069  if(is_wave32())
1070  {
1071  return dim3(kBlockSize / 2);
1072  }
1073  else
1074  {
1075  return dim3(kBlockSize);
1076  }
1077  }
1078 
1080  {
1081  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
1082  }
1083 
1084  CK_TILE_DEVICE void operator()(Kargs kargs) const
1085  {
1086  if constexpr(kIsAvailable)
1087  run_(std::move(kargs));
1088  }
1089 
1090  CK_TILE_DEVICE void run_(Kargs kargs) const
1091  {
1092  if constexpr(kPipelineName != "qr_async_trload")
1093  {
1094  // allocate LDS
1095  __shared__ char smem_ptr[GetSmemSize()];
1096 
1097  // divide problem
1098  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1099 
1100  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
1101  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
1102 
1103  long_index_t batch_offset_q = 0;
1104  long_index_t batch_offset_k = 0;
1105  long_index_t batch_offset_v = 0;
1106  long_index_t batch_offset_bias = 0;
1107  long_index_t batch_offset_randval = 0;
1108  long_index_t batch_offset_lse = 0;
1109  long_index_t batch_offset_o = 0;
1110 
1111  if constexpr(kIsGroupMode)
1112  {
1113  // Use seqstart_q_ptr and seqstart_k_ptr for physical starts
1114  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1115  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1116 
1117  // DRAM base offsets use physical starts
1118  batch_offset_q = query_start * kargs.stride_q;
1119  batch_offset_k = key_start * kargs.stride_k;
1120  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1121  {
1122  batch_offset_v = key_start * kargs.stride_v;
1123  }
1124  else
1125  {
1126  batch_offset_v = key_start;
1127  }
1128  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1129  {
1130  batch_offset_bias = query_start * kargs.stride_bias;
1131  }
1132  if constexpr(kStoreLSE)
1133  {
1134  // LSE follows the physical layout to stay consistent with other tensors
1135  batch_offset_lse = query_start;
1136  }
1137  if constexpr(kHasDropout)
1138  {
1139  batch_offset_randval = query_start * kargs.stride_randval;
1140  }
1141  batch_offset_o = query_start * kargs.stride_o;
1142 
1143  // real logical lengths (exclude PAD)
1144  // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
1145  if(kargs.seqlen_q_ptr != nullptr)
1146  {
1147  kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
1148  }
1149  else if(kargs.cu_seqlen_q_ptr != nullptr)
1150  {
1151  kargs.seqlen_q =
1152  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1153  }
1154  else
1155  {
1156  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1157  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1158  }
1159 
1160  if constexpr(kSkipMinSeqlenQ)
1161  {
1162  if(kargs.seqlen_q <= kargs.min_seqlen_q)
1163  {
1164  return;
1165  }
1166  }
1167 
1168  // terminate unnecessary blocks earlier
1169  if(kargs.seqlen_q <= i_m0)
1170  {
1171  return;
1172  }
1173 
1174  if(kargs.seqlen_k_ptr != nullptr)
1175  {
1176  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1177  }
1178  else if(kargs.cu_seqlen_k_ptr != nullptr)
1179  {
1180  kargs.seqlen_k =
1181  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1182  }
1183  else
1184  {
1185  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1186  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1187  }
1188  }
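// Editorial illustration of the group-mode block above: with a hypothetical
// cu_seqlen_q of {0, 3, 7, 12}, batch 1 gets
// seqlen_q = cu_seqlen_q[2] - cu_seqlen_q[1] = 4; the seqstart_* arrays are
// used the same way but index the padded physical layout.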
1189  else
1190  {
1191  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1192  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1193  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1194  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1195  {
1196  batch_offset_bias =
1197  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1198  }
1199  if constexpr(kStoreLSE)
1200  {
1201  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1202  }
1203  if constexpr(kHasDropout)
1204  {
1205  batch_offset_randval =
1206  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
1207  }
1208  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1209 
1210  // If cumulative seqlen pointers are provided, override per-batch effective lengths
1211  if(kargs.cu_seqlen_q_ptr != nullptr)
1212  {
1213  kargs.seqlen_q =
1214  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1215  }
1216  if(kargs.cu_seqlen_k_ptr != nullptr)
1217  {
1218  kargs.seqlen_k =
1219  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1220  }
1221  }
1222 
1223  // for simplicity, we fold the batch stride into the base pointers
1224  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1225  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1226  batch_offset_q;
1227  const KDataType* k_ptr =
1228  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1229  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1230  batch_offset_k;
1231  const VDataType* v_ptr =
1232  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1233  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1234  batch_offset_v;
1235  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1236  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1237  batch_offset_o;
1238 
1239  // Q/K/V DRAM and DRAM window
1240  const auto q_dram = [&]() {
1241  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1242  q_ptr,
1243  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1244  make_tuple(kargs.stride_q, 1),
1246  number<1>{});
1247  if constexpr(FmhaPipeline::kQLoadOnce)
1248  {
1249  return pad_tensor_view(q_dram_naive,
1253  }
1254  else
1255  {
1256  return pad_tensor_view(
1257  q_dram_naive,
1260  }
1261  }();
1262  const auto k_dram = [&]() {
1263  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1264  k_ptr,
1265  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1266  make_tuple(kargs.stride_k, 1),
1268  number<1>{});
1269 
1270  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1271  return pad_tensor_view(
1272  k_dram_naive,
1275  }();
1276  const auto v_dram = [&]() {
1277  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1278  {
1279  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1280  v_ptr,
1281  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1282  make_tuple(kargs.stride_v, 1),
1284  number<1>{});
1285 
1286  const auto v_dram_transposed = transform_tensor_view(
1287  v_dram_naive,
1289  make_pass_through_transform(kargs.seqlen_k)),
1292 
1293  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1294  return pad_tensor_view(
1295  v_dram_transposed,
1298  }
1299  else
1300  {
1301  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1302  v_ptr,
1303  make_tuple(kargs.hdim_v, kargs.seqlen_k),
1304  make_tuple(kargs.stride_v, 1),
1306  number<1>{});
1307 
1308  constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
1309  return pad_tensor_view(
1310  v_dram_naive,
1313  }
1314  }();
1315 
1316  auto q_dram_window = make_tile_window(
1317  q_dram,
1318  [&]() {
1319  if constexpr(FmhaPipeline::kQLoadOnce)
1322  else
1324  }(),
1325  {i_m0, 0});
1326 
1327  auto k_dram_window = make_tile_window(
1328  k_dram,
1330  {0, 0});
1331 
1332  auto v_dram_window = make_tile_window(
1333  v_dram,
1335  {i_n1, 0});
1338  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1339  constexpr auto bias_dram_window_lengths =
1340  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
1341  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1342  {
1343  const BiasDataType* bias_ptr =
1344  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1345  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1346  batch_offset_bias;
1347 
1348  const auto bias_dram = [&]() {
1349  const auto bias_dram_naive =
1350  make_naive_tensor_view<address_space_enum::global>(
1351  bias_ptr,
1352  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1353  make_tuple(kargs.stride_bias, 1),
1355  number<1>{});
1356 
1357  return pad_tensor_view(bias_dram_naive,
1358  bias_dram_window_lengths,
1360  }();
1361 
1362  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1363  }
1364  else
1365  {
1366  return make_null_tile_window(bias_dram_window_lengths);
1367  }
1368  }();
1369 
1370  // lse
1371  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1372  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1373  if constexpr(kStoreLSE)
1374  {
1375  LSEDataType* lse_ptr =
1376  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1377  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
1378  batch_offset_lse;
1379 
1380  const auto lse_dram = [&]() {
1381  const auto lse_dram_naive =
1382  make_naive_tensor_view<address_space_enum::global>(
1383  lse_ptr,
1384  make_tuple(kargs.seqlen_q),
1385  make_tuple(1),
1386  number<1>{},
1387  number<1>{});
1388 
1389  return pad_tensor_view(
1390  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1391  }();
1392 
1393  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1394  }
1395  else
1396  {
1397  return make_null_tile_window(lse_dram_window_lengths);
1398  }
1399  }();
1400 
1401  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1402  if constexpr(kHasDropout)
1403  {
1404  return BlockDropout{i_batch_,
1405  i_nhead_,
1406  kargs.num_head_q,
1407  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1408  : *kargs.drop_seed.ptr,
1409  kargs.is_drop_seed_offset_from_host
1410  ? kargs.drop_offset.val
1411  : *kargs.drop_offset.ptr,
1412  kargs.rp_undrop,
1413  kargs.p_undrop_in_uint8_t,
1414  kargs.is_store_randval};
1415  }
1416  else
1417  {
1418  return NullBlockDropout{};
1419  };
1420  }();
1421 
1422  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1423  constexpr auto randval_dram_window_lengths =
1424  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
1425  if constexpr(kHasDropout)
1426  {
1427  RandValOutputDataType* rand_val_ptr =
1428  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1429  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1430  batch_offset_randval;
1431 
1432  const auto randval_dram = [&]() {
1433  const auto randval_dram_naive =
1434  make_naive_tensor_view<address_space_enum::global>(
1435  rand_val_ptr,
1436  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1437  make_tuple(kargs.stride_randval, 1),
1439  number<1>{});
1440 
1441  return pad_tensor_view(randval_dram_naive,
1442  randval_dram_window_lengths,
1444  }();
1445 
1446  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1447  }
1448  else
1449  {
1450  return make_null_tile_window(randval_dram_window_lengths);
1451  }
1452  }();
1453 
1454  FmhaMask mask = [&]() {
1455  if constexpr(kHasMask)
1456  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1457  kargs.window_size_left,
1458  kargs.window_size_right,
1459  kargs.sink_size,
1460  kargs.seqlen_q,
1461  kargs.seqlen_k,
1463  else
1464  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1465  }();
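// Editorial illustration: (window_size_left, window_size_right) follow the
// usual flash-attention-style convention, e.g. a causal mask is typically
// (-1, 0) (unlimited look-back, no look-ahead) and (255, 0) a 256-wide sliding
// window; this convention is assumed here, not defined in this header.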
1466 
1467  // workaround: lambdas cannot capture structured bindings before C++20, hence the explicit copies
1468  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1469  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1470  {
1471  // data loading, shared by entire wg
1472  // TODO: how to use s_read?
1473  SaccDataType slope =
1474  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1475  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1476 #if CK_TILE_FMHA_FWD_FAST_EXP2
1477  slope *= ck_tile::log2e_v<>;
1478 #endif
1479  if constexpr(kHasMask)
1480  {
1481  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1482  kargs.window_size_left,
1483  kargs.window_size_right,
1484  kargs.seqlen_q,
1485  kargs.seqlen_k,
1486  kargs.mask_type);
1487  }
1488  else
1489  {
1490  return Alibi<SaccDataType, true>{
1491  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1492  }
1493  }
1494  else
1495  {
1496  return EmptyPositionEncoding<SaccDataType>{};
1497  }
1498  }();
1499 
1500  AttentionVariant variant;
1501  const auto variant_params = [&] {
1502  const float scale_s = [&] {
1503  if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
1504  {
1505  float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
1506  float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
1507 
1508  return kargs.scale_s * q_descale * k_descale;
1509  }
1510  else
1511  {
1512  return kargs.scale_s;
1513  }
1514  }();
1515 
1516  if constexpr(kHasLogitsSoftCap)
1517  {
1519  mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1520  }
1521  else
1522  {
1523  return ck_tile::StandardAttentionParams<FmhaMask>{mask, scale_s};
1524  }
1525  }();
1526 
1527  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1528 
1529  auto o_acc_tile = [&]() {
1530  if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
1531  {
1532  // TODO - move global load of descale to pipeline
1533  float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
1534 
1535  float scale_p =
1536  ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
1537  float scale_o = v_descale / scale_p;
1538 
1539  auto o_acc_element_func = [&]() {
1540  if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
1542  ck_tile::scales{scale_o});
1543  else
1544  return ck_tile::scales{scale_o};
1545  }();
1546  return FmhaPipeline{}(q_dram_window,
1547  identity{}, // q_element_func
1548  k_dram_window,
1549  identity{}, // k_element_func
1550  v_dram_window,
1551  identity{}, // v_element_func
1552  bias_dram_window,
1553  identity{}, // bias_element_func
1554  randval_dram_window,
1555  lse_dram_window,
1556  identity{}, // lse_element_func
1557  identity{}, // s_acc_element_func
1558  scales{scale_p}, // p_compute_element_func
1559  o_acc_element_func, // o_acc_element_func
1560  mask,
1561  position_encoding,
1562  variant_params.sm_scale,
1563  variant,
1564  variant_params,
1565  block_indices,
1566  smem_ptr,
1567  dropout);
1568  }
1569  else
1570  {
1571  return FmhaPipeline{}(q_dram_window,
1572  k_dram_window,
1573  v_dram_window,
1574  bias_dram_window,
1575  randval_dram_window,
1576  lse_dram_window,
1577  mask,
1578  position_encoding,
1579  variant_params.sm_scale,
1580  variant,
1581  variant_params,
1582  block_indices,
1583  smem_ptr,
1584  dropout);
1585  }
1586  }();
1587 
1588  // O DRAM and O DRAM window
1589  auto o_dram = [&]() {
1590  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1591  o_ptr,
1592  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1593  make_tuple(kargs.stride_o, 1),
1595  number<1>{});
1596 
1597  return pad_tensor_view(
1598  o_dram_naive,
1601  }();
1602 
1603  auto o_dram_window = make_tile_window(
1604  o_dram,
1606  {i_m0, i_n1});
1607 
1608  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1609  }
1610  else
1611  {
1612  // TODO: Refine the logic here.
1613  // In Decode case
1614  // 1. we don't expect KV data reused by different ThreadGroups, bypass the cache
1615  // 2. limit the LDS usage, as we want higher occupancy
1616  // In Prefill case
1617  // 1. we expect KV data reused by different ThreadGroups, use cache
1618  // 2. use more LDS, as we want better memory latency hiding
1619  // If SplitKV off, we don't expect Q data reused by different ThreadGroups, bypass the
1620  // cache
1621  constexpr bool PrefillCase = FmhaPipeline::kM0 > 64;
1622  // divide problem
1623  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1624 
1625  const index_t i_m0 = i_tile_m * FmhaPipeline::kM0;
1626  const index_t i_n1 = i_tile_n * FmhaPipeline::kN1;
1627 
1628  long_index_t batch_offset_q = 0;
1629  long_index_t batch_offset_k = 0; // unused for paged-kvcache
1630  long_index_t batch_offset_v = 0; // unused for paged-kvcache
1631  long_index_t batch_offset_bias = 0;
1632  long_index_t batch_offset_lse = 0;
1633  long_index_t batch_offset_o = 0;
1634  // index_t kv_l2p_offset =
1635  // 0; // logical-to-physical offset of seqlen_k coordinate. only used for
1636  // paged-kvcache
1637 
1638  if constexpr(kIsGroupMode)
1639  {
1640  // get starting offset for each batch - use seqstart_q_ptr/seqstart_k_ptr for
1641  // physical starts
1642  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1643  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1644 
1645  batch_offset_q = query_start * kargs.stride_q;
1646  batch_offset_k = key_start * kargs.stride_k;
1647  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1648  {
1649  batch_offset_v = key_start * kargs.stride_v;
1650  }
1651  else
1652  {
1653  // col-major V: offset along seqlen dimension is scalar index
1654  batch_offset_v = key_start;
1655  }
1656  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1657  {
1658  batch_offset_bias = query_start * kargs.stride_bias;
1659  }
1660 
1661  // LSE layout is [nhead, total_seqlen] following the physical layout for Q/O
1662  batch_offset_lse = query_start;
1663  batch_offset_o = query_start * kargs.stride_o;
1664 
1665  // get real # queries & # keys under group mode
1666  if(kargs.seqlen_q_ptr != nullptr)
1667  {
1668  kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
1669  }
1670  else if(kargs.cu_seqlen_q_ptr != nullptr)
1671  {
1672  kargs.seqlen_q =
1673  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1674  }
1675  else
1676  {
1677  kargs.seqlen_q =
1678  kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
1679  }
1680 
1681  // # of required blocks is different in each group, terminate unnecessary blocks
1682  // earlier
1683  if(kargs.seqlen_q <= i_m0)
1684  {
1685  return;
1686  }
1687 
1688  if(kargs.seqlen_k_ptr != nullptr)
1689  {
1690  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1691  }
1692  else if(kargs.cu_seqlen_k_ptr != nullptr)
1693  {
1694  kargs.seqlen_k =
1695  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1696  }
1697  else
1698  {
1699  kargs.seqlen_k =
1700  kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
1701  }
1702  }
1703  else
1704  {
1705  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1706  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1707  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1708  if constexpr(kStoreLSE)
1709  {
1710  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1711  }
1712  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1713 
1714  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1715  {
1716  batch_offset_bias =
1717  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1718  }
1719 
1720  // If cumulative seqlen pointers are provided, override per-batch effective lengths
1721  if(kargs.cu_seqlen_q_ptr != nullptr)
1722  {
1723  kargs.seqlen_q =
1724  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1725  }
1726  if(kargs.cu_seqlen_k_ptr != nullptr)
1727  {
1728  kargs.seqlen_k =
1729  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1730  }
1731  }
1732 
1733  // for simplicity, we fold the batch stride into the base pointers
1734  const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk;
1735 
1736  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1737  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1738  batch_offset_q;
1739  const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1740  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
1741  batch_offset_k;
1742  const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1743  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
1744  batch_offset_v;
1745 
1746  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1747  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1748  batch_offset_o;
1749 
1750  // Q/K/V DRAM and DRAM window
1751  const auto q_dram = [&] {
1752  const auto q_dram_naive = [&] {
1753  {
1754  return make_naive_tensor_view<address_space_enum::global,
1755  memory_operation_enum::set,
1757  q_ptr,
1758  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1759  make_tuple(kargs.stride_q, 1),
1761  number<1>{});
1762  }
1763  }();
1764 
1765  if constexpr(FmhaPipeline::kQLoadOnce)
1766  {
1767  const auto seqlen_q = kargs.seqlen_q;
1768  const auto q_dram_pad = pad_tensor_view(
1769  q_dram_naive,
1772 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1773  constexpr index_t LDSLayerSize = 256 / sizeof(QDataType);
1774  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1775 
1776  if constexpr(XorLengthFold > 1)
1777  {
1778  const auto q_dram_unmerged = transform_tensor_view(
1779  q_dram_pad,
1780  make_tuple(
1782  make_tuple(seqlen_q / XorLengthFold, XorLengthFold)),
1786 
1787  const auto q_dram_merged = transform_tensor_view(
1788  q_dram_unmerged,
1789  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1791  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1794 
1795  const auto q_dram_unmerged_xor = transform_tensor_view(
1796  q_dram_merged,
1797  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1803 
1804  const auto q_dram_permuted = transform_tensor_view(
1805  q_dram_unmerged_xor,
1806  make_tuple(
1808  make_tuple(seqlen_q / XorLengthFold,
1813 
1814  const auto q_dram_tmp = transform_tensor_view(
1815  q_dram_permuted,
1816  make_tuple(
1817  make_pass_through_transform(seqlen_q / XorLengthFold),
1820  number<FmhaPipeline::kQKHeaddim /
1821  FmhaPipeline::kAlignmentQ>{})),
1825 
1826  return transform_tensor_view(
1827  q_dram_tmp,
1828  make_tuple(
1830  make_tuple(seqlen_q / XorLengthFold, number<XorLengthFold>{})),
1836  }
1837  else
1838 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1839  {
1840  const auto q_dram_unmerged = transform_tensor_view(
1841  q_dram_pad,
1842  make_tuple(
1843  make_pass_through_transform(seqlen_q),
1849 
1850  const auto q_dram_permuted = transform_tensor_view(
1851  q_dram_unmerged,
1852  make_tuple(
1853  make_xor_transform(make_tuple(seqlen_q,
1854  number<FmhaPipeline::kQKHeaddim /
1855  FmhaPipeline::kAlignmentQ>{})),
1859 
1860  return transform_tensor_view(
1861  q_dram_permuted,
1862  make_tuple(
1863  make_pass_through_transform(seqlen_q),
1869  }
1870  }
1871  else
1872  {
1873  return pad_tensor_view(
1874  q_dram_naive,
1877  }
1878  }();
1879 
1880  const auto make_k_dram = [&](const KDataType* data, index_t height) {
1881  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1882  data, // will update this pointer if using paged-kvcache
1883  make_tuple(height, kargs.hdim_q),
1884  make_tuple(kargs.stride_k, 1),
1886  number<1>{});
1887 
1888  const auto k_dram_pad = pad_tensor_view(
1889  k_dram_naive,
1892 
1893  constexpr auto kDramTileK =
1894  FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0;
1895 
1896 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1897  constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
1898  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1899 
1900  if constexpr(XorLengthFold > 1)
1901  {
1902  const auto k_dram_unmerged = transform_tensor_view(
1903  k_dram_pad,
1905  make_tuple(height / XorLengthFold, XorLengthFold)),
1909 
1910  const auto k_dram_merged = transform_tensor_view(
1911  k_dram_unmerged,
1912  make_tuple(make_pass_through_transform(height / XorLengthFold),
1914  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1917 
1918  const auto k_dram_unmerged_xor = transform_tensor_view(
1919  k_dram_merged,
1920  make_tuple(make_pass_through_transform(height / XorLengthFold),
1926 
1927  const auto k_dram_permuted = transform_tensor_view(
1928  k_dram_unmerged_xor,
1929  make_tuple(
1931  make_tuple(height / XorLengthFold,
1936 
1937  const auto k_dram_tmp = transform_tensor_view(
1938  k_dram_permuted,
1939  make_tuple(
1940  make_pass_through_transform(height / XorLengthFold),
1943  number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
1947 
1948  return transform_tensor_view(
1949  k_dram_tmp,
1950  make_tuple(
1952  make_tuple(height / XorLengthFold, number<XorLengthFold>{})),
1958  }
1959  else
1960 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1961  {
1962  const auto k_dram_unmerged = transform_tensor_view(
1963  k_dram_pad,
1966  make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
1967  FmhaPipeline::kAlignmentK>{},
1968  number<kDramTileK / FmhaPipeline::kAlignmentK>{},
1972 
1973  const auto k_dram_permuted = transform_tensor_view(
1974  k_dram_unmerged,
1975  make_tuple(
1979  number<FmhaPipeline::kQKHeaddim / kDramTileK /
1980  FmhaPipeline::kAlignmentK>{}),
1984 
1985  return transform_tensor_view(
1986  k_dram_permuted,
1989  make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
1990  FmhaPipeline::kAlignmentK>{},
1991  number<kDramTileK / FmhaPipeline::kAlignmentK>{},
1995  }
1996  };
1997  const auto k_dram = [&]() {
1998  {
1999  return make_k_dram(k_ptr, kargs.seqlen_k);
2000  }
2001  }();
2002 
2003  const auto make_v_dram = [&](const VDataType* data, index_t length) {
2004  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
2005  data, // will update this pointer if using paged-kvcache
2006  make_tuple(length, kargs.hdim_v),
2007  make_tuple(kargs.stride_v, 1),
2009  number<1>{});
2010 
2011  // TODO: Add kVHeadDim
2012  constexpr index_t XorGroupSize =
2013  FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
2014 
2015  const auto v_dram_pad = pad_tensor_view(
2016  v_dram_naive,
2019 
2020 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2021  constexpr index_t LDSLayerSize = 256 / sizeof(VDataType);
2022  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
2023 
2024  if constexpr(XorLengthFold > 1)
2025  {
2026                 const auto v_dram_unmerged = transform_tensor_view(
2027                     v_dram_pad,
2028                     make_tuple(make_unmerge_transform(
2029                                    make_tuple(length / XorLengthFold, XorLengthFold)),
2030                                make_pass_through_transform(number<FmhaPipeline::kQKHeaddim>{})),
2031                     make_tuple(sequence<0>{}, sequence<1>{}),
2032                     make_tuple(sequence<0, 1>{}, sequence<2>{}));
2033 
2034                 const auto v_dram_merged = transform_tensor_view(
2035                     v_dram_unmerged,
2036                     make_tuple(make_pass_through_transform(length / XorLengthFold),
2037                                make_merge_transform_v3_division_mod(make_tuple(
2038                                    XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
2039                     make_tuple(sequence<0>{}, sequence<1, 2>{}),
2040                     make_tuple(sequence<0>{}, sequence<1>{}));
2041 
2042                 const auto v_dram_unmerged_xor = transform_tensor_view(
2043                     v_dram_merged,
2044                     make_tuple(
2045                         make_pass_through_transform(length / XorLengthFold),
2046                         make_unmerge_transform(make_tuple(XorLengthFold * FmhaPipeline::kQKHeaddim / XorGroupSize,
2047                                                           number<XorGroupSize>{}))),
2048                     make_tuple(sequence<0>{}, sequence<1>{}),
2049                     make_tuple(sequence<0>{}, sequence<1, 2>{}));
2050 
2051                 const auto v_dram_permuted = transform_tensor_view(
2052                     v_dram_unmerged_xor,
2053                     make_tuple(
2054                         make_xor_transform(make_tuple(length / XorLengthFold,
2055                                                       XorLengthFold * FmhaPipeline::kQKHeaddim /
2056                                                           XorGroupSize)),
2057                         make_pass_through_transform(number<XorGroupSize>{})),
2058                     make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0, 1>{}, sequence<2>{}));
2059 
2060                 const auto v_dram_tmp = transform_tensor_view(
2061                     v_dram_permuted,
2062                     make_tuple(make_pass_through_transform(length / XorLengthFold),
2063                                make_unmerge_transform(
2064                                    make_tuple(number<XorLengthFold>{},
2065                                               number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
2066                                make_pass_through_transform(number<XorGroupSize>{})),
2067                     make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
2068                     make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
2069 
2070                 return transform_tensor_view(
2071                     v_dram_tmp,
2072                     make_tuple(make_merge_transform_v3_division_mod(
2073                                    make_tuple(length / XorLengthFold, number<XorLengthFold>{})),
2074                                make_merge_transform_v3_division_mod(
2075                                    make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
2076                                               number<XorGroupSize>{}))),
2077                     make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
2078                     make_tuple(sequence<0>{}, sequence<1>{}));
2079             }
2080  else
2081 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2082  {
2083                 const auto v_dram_unmerged = transform_tensor_view(
2084                     v_dram_pad,
2085                     make_tuple(make_pass_through_transform(length),
2086                                make_unmerge_transform(
2087                                    make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
2088                                               number<XorGroupSize>{}))),
2089                     make_tuple(sequence<0>{}, sequence<1>{}),
2090                     make_tuple(sequence<0>{}, sequence<1, 2>{}));
2091 
2092                 const auto v_dram_permuted = transform_tensor_view(
2093                     v_dram_unmerged,
2094                     make_tuple(
2095                         make_xor_transform(
2096                             make_tuple(length, number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
2097                         make_pass_through_transform(number<XorGroupSize>{})),
2098                     make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0, 1>{}, sequence<2>{}));
2099 
2100                 return transform_tensor_view(
2101                     v_dram_permuted,
2102                     make_tuple(make_pass_through_transform(length),
2103                                make_merge_transform_v3_division_mod(
2104                                    make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
2105                                               number<XorGroupSize>{}))),
2106                     make_tuple(sequence<0>{}, sequence<1, 2>{}),
2107                     make_tuple(sequence<0>{}, sequence<1>{}));
2108             }
2109  };
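            // sketch of the swizzle: make_xor_transform(make_tuple(H, W)) remaps column c
            // in row r to roughly c ^ (r % W), e.g. for W = 8 row 1 reads columns
            // {1,0,3,2,5,4,7,6}; rows thus land in different LDS banks for transposed reads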
2110 
2111             const auto v_dram = [&]() {
2112                 return make_v_dram(v_ptr, kargs.seqlen_k);
2113             }();
2116 
2117             auto q_dram_window = make_tile_window(
2118                 q_dram,
2119                 [&]() {
2120                     if constexpr(FmhaPipeline::kQLoadOnce)
2121                         return make_tuple(number<FmhaPipeline::kM0>{},
2122                                           number<FmhaPipeline::kSubQKHeaddim>{});
2123                     else
2124                         return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
2125                 }(),
2126                 {i_m0, 0});
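            // window origin {i_m0, 0}: this workgroup covers query rows [i_m0, i_m0 + kM0)
            // starting from the first head-dim element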
2127 
2128             auto k_dram_window = make_tile_window(
2129                 k_dram,
2130                 make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
2131                 {0, 0});
2132 
2133             auto v_dram_window = make_tile_window(
2134                 v_dram,
2135                 make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
2136                 {0, 0});
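            // K/V windows start at {0, 0}; the pipeline slides them along seqlen_k as it
            // consumes successive key/value tiles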
2137 
2140  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
2141                 constexpr auto bias_dram_window_lengths =
2142                     make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
2143                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
2144                 {
2145  const BiasDataType* bias_ptr =
2146  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
2147  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
2148  batch_offset_bias;
2149 
2150  const auto bias_dram = [&]() {
2151  const auto bias_dram_naive =
2152  make_naive_tensor_view<address_space_enum::global>(
2153  bias_ptr,
2154  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
2155                             make_tuple(kargs.stride_bias, 1),
2156                             number<FmhaPipeline::kAlignmentBias>{},
2157                             number<1>{});
2158 
2159                     return pad_tensor_view(bias_dram_naive,
2160                                            bias_dram_window_lengths,
2161                                            sequence<kPadSeqLenQ, kPadSeqLenK>{});
2162                 }();
2163 
2164  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
2165  }
2166  else
2167  {
2168  return make_null_tile_window(bias_dram_window_lengths);
2169  }
2170  }();
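            // when bias is disabled, the null tile window is a compile-time placeholder
            // that keeps the pipeline call signature uniform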
2171 
2172         // LSE DRAM and LSE DRAM window
2173  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
2174  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
2175  if constexpr(kStoreLSE)
2176  {
2177  LSEDataType* lse_ptr =
2178  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
2179  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
2180  batch_offset_lse;
2181 
2182                 const auto lse_dram = [&] {
2183                     const auto lse_dram_naive =
2184                         make_naive_tensor_view<address_space_enum::global>(
2185                             lse_ptr,
2186                             make_tuple(kargs.seqlen_q),
2187                             make_tuple(1),
2188                             number<1>{},
2189                             number<1>{});
2190                     return pad_tensor_view(
2191                         lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
2192                 }();
2196 
2197  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
2198  }
2199  else
2200  {
2201  return make_null_tile_window(lse_dram_window_lengths);
2202  }
2203  }();
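            // LSE holds the per-row log-sum-exp of the softmax (consumed e.g. by the
            // backward pass); with kStoreLSE == false the null window makes it a no-op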
2204 
2205  FmhaMask mask = [&]() {
2206  if constexpr(kHasMask)
2207  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
2208  kargs.window_size_left,
2209  kargs.window_size_right,
2210  kargs.sink_size,
2211  kargs.seqlen_q,
2212                     kargs.seqlen_k,
2213                     kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
2214             else
2215  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
2216  }();
2217 
2218         // workaround: i_batch/i_nhead come from a structured binding, which a lambda
2219         // cannot capture before C++20, so re-bind them by value here
2219  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
2220  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
2221  {
2222                 // data loading, shared by the entire workgroup
2223  // TODO: how to use s_read?
2224  SaccDataType slope =
2225  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
2226  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
2227 #if CK_TILE_FMHA_FWD_FAST_EXP2
2228  slope *= ck_tile::log2e_v<>;
2229 #endif
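                // pre-scaling the slope by log2(e) lets the pipeline evaluate the ALIBI
                // bias through exp2() instead of exp() without changing its value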
2230  if constexpr(kHasMask)
2231  {
2232  return make_alibi_from_lr_mask<SaccDataType, true, 32>(
2233  slope,
2234  kargs.window_size_left,
2235  kargs.window_size_right,
2236  kargs.seqlen_q,
2237  kargs.seqlen_k,
2238  kargs.mask_type);
2239  }
2240  else
2241             {
2242                 return Alibi<SaccDataType, true>{
2243                     slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
2244  }
2245  }
2246  else
2247             {
2248                 return EmptyPositionEncoding<SaccDataType>{};
2249             }
2250  }();
2251 
2252  auto o_acc_tile = [&]() {
2253  if constexpr(PrefillCase)
2254  {
2255                 // allocate double (ping-pong) LDS buffers for K and V
2256                 // (note: __restrict__ could be added here to rule out aliasing)
2257  __shared__ char smem_ptrk0
2258  [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
2259  true>()];
2260  __shared__ char smem_ptrk1
2261  [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
2262  true>()];
2263  __shared__ char smem_ptrv0[FmhaPipeline::Policy::template GetSmemSizeV<
2264  typename FmhaPipeline::Problem>()];
2265  __shared__ char smem_ptrv1[FmhaPipeline::Policy::template GetSmemSizeV<
2266  typename FmhaPipeline::Problem>()];
2267 
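                // ping-pong LDS: the pipeline can overlap global->LDS copies into one K/V
                // buffer with reads from the other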
2268  return FmhaPipeline{}(q_dram_window,
2269  k_dram_window,
2270  v_dram_window,
2271  bias_dram_window,
2272  lse_dram_window,
2273  mask,
2274  position_encoding,
2275  kargs.scale_s,
2276  smem_ptrk0,
2277  smem_ptrk1,
2278  smem_ptrv0,
2279  smem_ptrv1);
2280  }
2281  else
2282  {
2283  __shared__ char smem_ptr[GetSmemSize()];
2284  return FmhaPipeline{}(q_dram_window,
2285  k_dram_window,
2286  v_dram_window,
2287  bias_dram_window,
2288  lse_dram_window,
2289  mask,
2290  position_encoding,
2291  kargs.scale_s,
2292  smem_ptr);
2293  }
2294  }();
2295 
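        // o_acc_tile is this workgroup's accumulated output tile held in registers;
        // the epilogue converts it to ODataType and stores it below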
2296         // O DRAM and O DRAM window
2297         auto o_dram = [&] {
2298             const auto o_dram_naive =
2299                 make_naive_tensor_view<address_space_enum::global>(
2300                     o_ptr,
2301                     make_tuple(kargs.seqlen_q, kargs.hdim_v),
2302                     make_tuple(kargs.stride_o, 1),
2303                     number<FmhaPipeline::kAlignmentO>{},
2304                     number<1>{});
2305 
2306             return pad_tensor_view(
2307                 o_dram_naive,
2308                 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
2309                 sequence<kPadSeqLenQ, kPadHeadDimV>{});
2310         }();
2314 
2315         auto o_dram_window = make_tile_window(
2316             o_dram,
2317             make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
2318             {i_m0, i_n1});
2319 
2320  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
2321  }
2322  }
2323 };
2324 
2325 } // namespace ck_tile