Composable Kernel: include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
11 
12 #include <string>
13 #include <type_traits>
14 #include <utility>
15 #include <variant>
16 
17 #define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
18 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
19 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
20 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
21 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
22 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
23 
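// Illustrative sketch (not part of the original header): a scalar reference of the
// computation described above, assuming row-major Q[seqlen_q][hdim_q],
// K[seqlen_k][hdim_q], V[seqlen_k][hdim_v] and a hypothetical softmax_row() helper.
//
//   for(int i = 0; i < seqlen_q; ++i) {
//       std::vector<float> s(seqlen_k);
//       for(int j = 0; j < seqlen_k; ++j) {
//           float acc = 0.f;
//           for(int d = 0; d < hdim_q; ++d)
//               acc += Q[i][d] * K[j][d];      // S  = Q @ K^T
//           s[j] = acc * scale_s + Bias[i][j]; // S'' = S * Scale + Bias
//       }
//       softmax_row(s);                        // P = Softmax(S'')
//       for(int d = 0; d < hdim_v; ++d) {
//           float acc = 0.f;
//           for(int j = 0; j < seqlen_k; ++j)
//               acc += s[j] * V[j][d];         // O = P @ V
//           O[i][d] = acc;
//       }
//   }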
24 namespace ck_tile {
25 
26 template <typename FmhaPipeline_, typename EpiloguePipeline_>
27 struct FmhaFwdKernel
28 {
29  using FmhaPipeline     = ck_tile::remove_cvref_t<FmhaPipeline_>;
30  using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
31  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
32 
33  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
34  static_assert(kBlockPerCu > 0);
35  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
36 
47 
49 
50  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
51  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
52  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
53  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
54  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
55  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
56  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
57  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
58  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
59  static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
60  static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
61 
64  static constexpr bool kHasMask = FmhaMask::IsMasking;
65 
66  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
67 
68  static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
69 #if defined(__gfx950__)
70  static constexpr bool kIsAvailable = true;
71 #else
72  static constexpr bool kIsAvailable = !kUseTrLoad;
73 #endif
74  static constexpr std::string_view kPipelineName = FmhaPipeline::name;
75 
76  template <ck_tile::index_t I> // to avoid the duplicated base class problem,
77  // introduce a template arg
78  struct FmhaFwdEmptyKargs
79  {
80  };
81 
82  // kargs use aggregate initialization, so no constructor is provided
83  // use inheritance to minimize karg size
84  // users need to use the MakeKargs() function to create kargs.
85  struct FmhaFwdCommonKargs
86  {
87  const void* q_ptr;
88  const void* k_ptr;
89  const void* v_ptr;
90  void* o_ptr;
91 
96 
98  // for MQA/GQA, nhead_q and nhead_k can differ. This parameter is nhead_q / nhead_k
99  // if this param is larger than 1, it indicates an MQA/GQA case
101  float scale_s;
102 
107 
112  };
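// Illustrative note (not part of the original header): with MQA/GQA several query
// heads share one KV head, and nhead_ratio_qk = nhead_q / nhead_k selects which one.
// A minimal sketch of the mapping used later in this file:
//
//   // e.g. nhead_q = 32, nhead_k = 8  ->  nhead_ratio_qk = 4
//   // query heads 0..3 read KV head 0, heads 4..7 read KV head 1, ...
//   const ck_tile::index_t i_nhead_k = i_nhead / nhead_ratio_qk;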
113 
114  struct FmhaFwdLogitsSoftCapKargs
115  {
117 
118  void init_logits_soft_cap(float logits_soft_cap_)
119  {
120  if(0 < logits_soft_cap_)
121  {
122  logits_soft_cap = logits_soft_cap_;
124  }
125  else
126  {
127  logits_soft_cap = 0.f;
128  logits_soft_cap_rcp = 0.f;
129  }
130  }
131 
132  float logits_soft_cap;
133  float logits_soft_cap_rcp;
134  };
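// Illustrative note (not part of the original header): logits soft-capping squashes
// attention logits into (-cap, cap), and logits_soft_cap_rcp caches the reciprocal
// so the hot loop can multiply instead of divide. A common formulation (assumed
// here, not confirmed by this listing) is:
//
//   float soft_cap(float s, float cap, float cap_rcp)
//   {
//       return cap * tanhf(s * cap_rcp); // s' = cap * tanh(s / cap)
//   }
//
// Passing logits_soft_cap_ <= 0 disables the feature (both fields become 0).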
135 
136  struct FmhaFwdCommonBiasKargs
137  {
138  const void* bias_ptr = nullptr;
141  };
142 
143  struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs
144  {
146  };
147 
148  struct FmhaFwdAlibiKargs
149  {
150  // alibi is batch*nhead*1; the layout is the same in batch and group mode
151  const void* alibi_slope_ptr;
152  ck_tile::index_t alibi_slope_stride; // stride between batches, or 0 if all batches share the same slopes
153  };
154 
155  struct FmhaFwdMaskKargs
156  {
157  // ck_tile::index_t window_size_left, window_size_right;
160  };
161 
162  struct FmhaFwdCommonQScaleKargs
163  {
164  const void* q_descale_ptr = nullptr;
165  const void* k_descale_ptr = nullptr;
166  const void* v_descale_ptr = nullptr;
167  };
168 
169  struct FmhaFwdCommonLSEKargs
170  {
171  void* lse_ptr = nullptr;
174  };
175 
177  {
178  template <typename T>
180  {
181  T val;
182  const T* ptr;
183  };
184 
188  };
189 
191  {
192  void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
193  {
194  float p_undrop = 1.0 - p_drop;
197  rp_undrop = 1.0 / p_undrop;
198 
199  this->drop_seed.val = seed;
200  this->drop_offset.val = offset;
201  this->is_drop_seed_offset_from_host = true;
202  }
203 
204  void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
205  {
206  float p_undrop = 1.0 - p_drop;
209  rp_undrop = 1.0 / p_undrop;
210 
211  this->drop_seed.ptr = seed_ptr;
212  this->drop_offset.ptr = offset_ptr;
213  this->is_drop_seed_offset_from_host = false;
214  }
215 
216  float rp_undrop = 1;
218  bool is_store_randval = false;
219  void* rand_val_ptr = nullptr;
220 
223  };
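// Illustrative note (not part of the original header): init_dropout() precomputes
// rp_undrop = 1 / (1 - p_drop), the factor that rescales surviving attention
// weights so their expectation is unchanged. The seed/offset pair is stored either
// by value (from host) or by pointer (read on device), discriminated by
// is_drop_seed_offset_from_host. Host-side math for p_drop = 0.1:
//
//   float p_undrop  = 1.0f - 0.1f;      // keep probability = 0.9
//   float rp_undrop = 1.0f / p_undrop;  // kept values scaled by ~1.1111f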
224 
226  {
228  };
229 
231  {
233  };
234 
235  struct FmhaFwdBatchModeKargs
236  : FmhaFwdCommonKargs,
237  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
238  FmhaFwdBatchModeBiasKargs,
239  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
240  FmhaFwdAlibiKargs,
241  FmhaFwdEmptyKargs<0>>>,
242  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
243  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
244  std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
245  FmhaFwdCommonQScaleKargs,
246  FmhaFwdEmptyKargs<3>>,
247  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
248  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
249  {
250  ck_tile::index_t batch_stride_q;
251  ck_tile::index_t batch_stride_k;
252  ck_tile::index_t batch_stride_v;
253  ck_tile::index_t batch_stride_o;
254 
255  // Optional cumulative sequence length pointers for batch mode
256  // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
257  const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
258  const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD
259  };
260 
261  struct FmhaFwdGroupModeKargs
262  : FmhaFwdCommonKargs,
263  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
264  FmhaFwdCommonBiasKargs,
265  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
266  FmhaFwdAlibiKargs,
267  FmhaFwdEmptyKargs<0>>>,
268  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
269  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
270  std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
271  FmhaFwdCommonQScaleKargs,
272  FmhaFwdEmptyKargs<3>>,
273  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
274  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
275  std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
276  {
277  const int32_t* seqstart_q_ptr;
278  const int32_t* seqstart_k_ptr;
279  const int32_t* seqlen_q_ptr;
280  const int32_t* seqlen_k_ptr;
281 
282  // Optional per-sequence and cumulative logical (excluding padding) sequence length arrays
283  const int32_t* cu_seqlen_q_ptr = nullptr;
284  const int32_t* cu_seqlen_k_ptr = nullptr;
285  };
286 
287  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
288 
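// Illustrative sketch (not part of the original header) of the kargs composition
// pattern above: optional argument groups are mixed in via std::conditional_t, and
// a distinct FmhaFwdEmptyKargs<I> stands in for each disabled group so the empty
// bases stay unique and, on typical ABIs, add no size. Hypothetical stand-in types:
//
//   template <int I> struct Empty {};
//   struct Common   { const void* q_ptr; };
//   struct MaskArgs { int window_left, window_right; };
//   template <bool HasMask>
//   struct KargsSketch : Common, std::conditional_t<HasMask, MaskArgs, Empty<1>> {};
//   static_assert(sizeof(KargsSketch<false>) == sizeof(Common)); // empty base optimization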
290  {
294  };
295 
296  template <bool Cond = !kIsGroupMode>
297  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
298  MakeKargsImpl(const void* q_ptr,
299  const void* k_ptr,
300  const void* v_ptr,
301  const void* bias_ptr,
302  const void* q_descale_ptr,
303  const void* k_descale_ptr,
304  const void* v_descale_ptr,
305  void* rand_val_ptr,
306  void* lse_ptr,
307  void* o_ptr,
308  ck_tile::index_t seqlen_q,
309  ck_tile::index_t seqlen_k,
310  ck_tile::index_t hdim_q,
311  ck_tile::index_t hdim_v,
312  ck_tile::index_t num_head_q,
313  ck_tile::index_t nhead_ratio_qk,
314  float scale_s,
315  float logits_soft_cap,
316  ck_tile::index_t stride_q,
317  ck_tile::index_t stride_k,
318  ck_tile::index_t stride_v,
319  ck_tile::index_t stride_bias,
320  ck_tile::index_t stride_randval,
321  ck_tile::index_t stride_o,
322  ck_tile::index_t nhead_stride_q,
323  ck_tile::index_t nhead_stride_k,
324  ck_tile::index_t nhead_stride_v,
325  ck_tile::index_t nhead_stride_bias,
326  ck_tile::index_t nhead_stride_randval,
327  ck_tile::index_t nhead_stride_lse,
328  ck_tile::index_t nhead_stride_o,
329  ck_tile::index_t batch_stride_q,
330  ck_tile::index_t batch_stride_k,
331  ck_tile::index_t batch_stride_v,
332  ck_tile::index_t batch_stride_bias,
333  ck_tile::index_t batch_stride_randval,
334  ck_tile::index_t batch_stride_lse,
335  ck_tile::index_t batch_stride_o,
336  ck_tile::index_t window_size_left,
337  ck_tile::index_t window_size_right,
338  ck_tile::index_t mask_type,
339  float p_drop,
340  bool s_randval,
341  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
342  drop_seed_offset,
343  const void* cu_seqlen_q_ptr = nullptr,
344  const void* cu_seqlen_k_ptr = nullptr)
345  {
346  Kargs kargs{{q_ptr,
347  k_ptr,
348  v_ptr,
349  o_ptr,
350  seqlen_q,
351  seqlen_k,
352  hdim_q,
353  hdim_v,
354  num_head_q,
355  nhead_ratio_qk,
356 #if CK_TILE_FMHA_FWD_FAST_EXP2
357  static_cast<float>(scale_s * ck_tile::log2e_v<>),
358 #else
359  scale_s,
360 #endif
361  stride_q,
362  stride_k,
363  stride_v,
364  stride_o,
365  nhead_stride_q,
366  nhead_stride_k,
367  nhead_stride_v,
368  nhead_stride_o}, // args for common karg
369  {}, // placeholder for bias
370  {}, // placeholder for mask
371  {}, // placeholder for lse
372  {}, // placeholder for qscale
373  {}, // placeholder for dropout
374  {}, // placeholder for logits_soft_cap
375  batch_stride_q,
376  batch_stride_k,
377  batch_stride_v,
378  batch_stride_o};
379 
380  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
381  {
382  kargs.bias_ptr = bias_ptr;
383  kargs.stride_bias = stride_bias;
384  kargs.nhead_stride_bias = nhead_stride_bias;
385  kargs.batch_stride_bias = batch_stride_bias;
386  }
387  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
388  {
389  kargs.alibi_slope_ptr = bias_ptr;
390  kargs.alibi_slope_stride = stride_bias;
391  }
392  if constexpr(kHasMask)
393  {
394  kargs.window_size_left = window_size_left;
395  kargs.window_size_right = window_size_right;
396  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
397  }
398  if constexpr(kStoreLSE)
399  {
400  kargs.lse_ptr = lse_ptr;
401  kargs.nhead_stride_lse = nhead_stride_lse;
402  kargs.batch_stride_lse = batch_stride_lse;
403  }
404  if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
405  {
406  kargs.q_descale_ptr = q_descale_ptr;
407  kargs.k_descale_ptr = k_descale_ptr;
408  kargs.v_descale_ptr = v_descale_ptr;
409  }
410  if constexpr(kHasDropout)
411  {
412  if(drop_seed_offset.index() == 0) // seed & offset come from host
413  {
414  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
415  kargs.init_dropout(p_drop, seed, offset);
416  }
417  else // seed & offset come from device
418  {
419  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
420  kargs.init_dropout(p_drop,
421  reinterpret_cast<const uint64_t*>(seed_ptr),
422  reinterpret_cast<const uint64_t*>(offset_ptr));
423  }
424 
425  kargs.rand_val_ptr = rand_val_ptr;
426  kargs.stride_randval = stride_randval;
427  kargs.nhead_stride_randval = nhead_stride_randval;
428  kargs.batch_stride_randval = batch_stride_randval;
429  kargs.is_store_randval = s_randval;
430  }
431  if constexpr(kHasLogitsSoftCap)
432  {
433  kargs.init_logits_soft_cap(logits_soft_cap);
434  }
435 
436  kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
437  kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
438  return kargs;
439  }
440 
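// Illustrative note (not part of the original header): when CK_TILE_FMHA_FWD_FAST_EXP2
// is enabled, log2(e) is folded into scale_s above so the softmax can use the cheaper
// exp2 instruction, relying on the identity
//
//   expf(s * scale) == exp2f(s * scale * log2e); // log2e ~= 1.4426950408889634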
441  // std::variant<> can't be constructed from a braced initializer list; overload for backward compatibility
442  template <bool Cond = !kIsGroupMode>
443  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
444  MakeKargs(const void* q_ptr,
445  const void* k_ptr,
446  const void* v_ptr,
447  const void* bias_ptr,
448  const void* q_descale_ptr,
449  const void* k_descale_ptr,
450  const void* v_descale_ptr,
451  void* rand_val_ptr,
452  void* lse_ptr,
453  void* o_ptr,
454  ck_tile::index_t seqlen_q,
455  ck_tile::index_t seqlen_k,
456  ck_tile::index_t hdim_q,
457  ck_tile::index_t hdim_v,
458  ck_tile::index_t num_head_q,
459  ck_tile::index_t nhead_ratio_qk,
460  float scale_s,
461  float logits_soft_cap,
462  ck_tile::index_t stride_q,
463  ck_tile::index_t stride_k,
464  ck_tile::index_t stride_v,
465  ck_tile::index_t stride_bias,
466  ck_tile::index_t stride_randval,
467  ck_tile::index_t stride_o,
468  ck_tile::index_t nhead_stride_q,
469  ck_tile::index_t nhead_stride_k,
470  ck_tile::index_t nhead_stride_v,
471  ck_tile::index_t nhead_stride_bias,
472  ck_tile::index_t nhead_stride_randval,
473  ck_tile::index_t nhead_stride_lse,
474  ck_tile::index_t nhead_stride_o,
475  ck_tile::index_t batch_stride_q,
476  ck_tile::index_t batch_stride_k,
477  ck_tile::index_t batch_stride_v,
478  ck_tile::index_t batch_stride_bias,
479  ck_tile::index_t batch_stride_randval,
480  ck_tile::index_t batch_stride_lse,
481  ck_tile::index_t batch_stride_o,
482  ck_tile::index_t window_size_left,
483  ck_tile::index_t window_size_right,
484  ck_tile::index_t mask_type,
485  float p_drop,
486  bool s_randval,
487  const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
488  const void* cu_seqlen_q_ptr = nullptr,
489  const void* cu_seqlen_k_ptr = nullptr)
490  {
491  return MakeKargsImpl(
492  q_ptr,
493  k_ptr,
494  v_ptr,
495  bias_ptr,
496  q_descale_ptr,
497  k_descale_ptr,
498  v_descale_ptr,
499  rand_val_ptr,
500  lse_ptr,
501  o_ptr,
502  seqlen_q,
503  seqlen_k,
504  hdim_q,
505  hdim_v,
506  num_head_q,
507  nhead_ratio_qk,
508  scale_s,
509  logits_soft_cap,
510  stride_q,
511  stride_k,
512  stride_v,
513  stride_bias,
514  stride_randval,
515  stride_o,
516  nhead_stride_q,
517  nhead_stride_k,
518  nhead_stride_v,
519  nhead_stride_bias,
520  nhead_stride_randval,
521  nhead_stride_lse,
522  nhead_stride_o,
523  batch_stride_q,
524  batch_stride_k,
525  batch_stride_v,
526  batch_stride_bias,
527  batch_stride_randval,
528  batch_stride_lse,
529  batch_stride_o,
530  window_size_left,
531  window_size_right,
532  mask_type,
533  p_drop,
534  s_randval,
535  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
536  cu_seqlen_q_ptr,
537  cu_seqlen_k_ptr);
538  }
539 
540  // std::variant<> can't be constructed from a braced initializer list; overload for backward compatibility
541  template <bool Cond = !kIsGroupMode>
542  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
543  MakeKargs(const void* q_ptr,
544  const void* k_ptr,
545  const void* v_ptr,
546  const void* bias_ptr,
547  const void* q_descale_ptr,
548  const void* k_descale_ptr,
549  const void* v_descale_ptr,
550  void* rand_val_ptr,
551  void* lse_ptr,
552  void* o_ptr,
553  ck_tile::index_t seqlen_q,
554  ck_tile::index_t seqlen_k,
555  ck_tile::index_t hdim_q,
556  ck_tile::index_t hdim_v,
557  ck_tile::index_t num_head_q,
558  ck_tile::index_t nhead_ratio_qk,
559  float scale_s,
560  float logits_soft_cap,
561  ck_tile::index_t stride_q,
562  ck_tile::index_t stride_k,
563  ck_tile::index_t stride_v,
564  ck_tile::index_t stride_bias,
565  ck_tile::index_t stride_randval,
566  ck_tile::index_t stride_o,
567  ck_tile::index_t nhead_stride_q,
568  ck_tile::index_t nhead_stride_k,
569  ck_tile::index_t nhead_stride_v,
570  ck_tile::index_t nhead_stride_bias,
571  ck_tile::index_t nhead_stride_randval,
572  ck_tile::index_t nhead_stride_lse,
573  ck_tile::index_t nhead_stride_o,
574  ck_tile::index_t batch_stride_q,
575  ck_tile::index_t batch_stride_k,
576  ck_tile::index_t batch_stride_v,
577  ck_tile::index_t batch_stride_bias,
578  ck_tile::index_t batch_stride_randval,
579  ck_tile::index_t batch_stride_lse,
580  ck_tile::index_t batch_stride_o,
581  ck_tile::index_t window_size_left,
582  ck_tile::index_t window_size_right,
583  ck_tile::index_t mask_type,
584  float p_drop,
585  bool s_randval,
586  const std::tuple<const void*, const void*>& drop_seed_offset,
587  const void* cu_seqlen_q_ptr = nullptr,
588  const void* cu_seqlen_k_ptr = nullptr)
589  {
590  return MakeKargsImpl(
591  q_ptr,
592  k_ptr,
593  v_ptr,
594  bias_ptr,
595  q_descale_ptr,
596  k_descale_ptr,
597  v_descale_ptr,
598  rand_val_ptr,
599  lse_ptr,
600  o_ptr,
601  seqlen_q,
602  seqlen_k,
603  hdim_q,
604  hdim_v,
605  num_head_q,
606  nhead_ratio_qk,
607  scale_s,
608  logits_soft_cap,
609  stride_q,
610  stride_k,
611  stride_v,
612  stride_bias,
613  stride_randval,
614  stride_o,
615  nhead_stride_q,
616  nhead_stride_k,
617  nhead_stride_v,
618  nhead_stride_bias,
619  nhead_stride_randval,
620  nhead_stride_lse,
621  nhead_stride_o,
622  batch_stride_q,
623  batch_stride_k,
624  batch_stride_v,
625  batch_stride_bias,
626  batch_stride_randval,
627  batch_stride_lse,
628  batch_stride_o,
629  window_size_left,
630  window_size_right,
631  mask_type,
632  p_drop,
633  s_randval,
634  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
635  cu_seqlen_q_ptr,
636  cu_seqlen_k_ptr);
637  }
638 
639  template <bool Cond = kIsGroupMode>
640  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
641  MakeKargsImpl(const void* q_ptr,
642  const void* k_ptr,
643  const void* v_ptr,
644  const void* bias_ptr,
645  const void* q_descale_ptr,
646  const void* k_descale_ptr,
647  const void* v_descale_ptr,
648  void* rand_val_ptr,
649  void* lse_ptr,
650  void* o_ptr,
651  const void* seqstart_q_ptr,
652  const void* seqstart_k_ptr,
653  const void* seqlen_q_ptr,
654  const void* seqlen_k_ptr,
655  ck_tile::index_t hdim_q,
656  ck_tile::index_t hdim_v,
657  ck_tile::index_t num_head_q,
658  ck_tile::index_t nhead_ratio_qk,
659  float scale_s,
660  float logits_soft_cap,
661  ck_tile::index_t stride_q,
662  ck_tile::index_t stride_k,
663  ck_tile::index_t stride_v,
664  ck_tile::index_t stride_bias,
665  ck_tile::index_t stride_randval,
666  ck_tile::index_t stride_o,
667  ck_tile::index_t nhead_stride_q,
668  ck_tile::index_t nhead_stride_k,
669  ck_tile::index_t nhead_stride_v,
670  ck_tile::index_t nhead_stride_bias,
671  ck_tile::index_t nhead_stride_randval,
672  ck_tile::index_t nhead_stride_lse,
673  ck_tile::index_t nhead_stride_o,
674  ck_tile::index_t window_size_left,
675  ck_tile::index_t window_size_right,
676  ck_tile::index_t mask_type,
677  ck_tile::index_t min_seqlen_q,
678  float p_drop,
679  bool s_randval,
680  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
681  drop_seed_offset,
682  const void* cu_seqlen_q_ptr = nullptr,
683  const void* cu_seqlen_k_ptr = nullptr)
684  {
685  Kargs kargs{{q_ptr,
686  k_ptr,
687  v_ptr,
688  o_ptr,
689  -1, // seqlen will be updated by another pointer
690  -1, //
691  hdim_q,
692  hdim_v,
693  num_head_q,
694  nhead_ratio_qk,
695 #if CK_TILE_FMHA_FWD_FAST_EXP2
696  static_cast<float>(scale_s * ck_tile::log2e_v<>),
697 #else
698  scale_s,
699 #endif
700  stride_q,
701  stride_k,
702  stride_v,
703  stride_o,
704  nhead_stride_q,
705  nhead_stride_k,
706  nhead_stride_v,
707  nhead_stride_o}, // args for common karg
708  {}, // placeholder for bias
709  {}, // placeholder for mask
710  {}, // placeholder for lse
711  {}, // placeholder for qscale
712  {}, // placeholder for dropout
713  {}, // placeholder for logits_soft_cap
714  {}, // placeholder for min_seqlen_q
715  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
716  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
717  reinterpret_cast<const int32_t*>(seqlen_q_ptr),
718  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
719 
720  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
721  {
722  kargs.bias_ptr = bias_ptr;
723  kargs.stride_bias = stride_bias;
724  kargs.nhead_stride_bias = nhead_stride_bias;
725  }
726  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
727  {
728  kargs.alibi_slope_ptr = bias_ptr;
729  kargs.alibi_slope_stride = stride_bias;
730  }
731  if constexpr(kHasMask)
732  {
733  kargs.window_size_left = window_size_left;
734  kargs.window_size_right = window_size_right;
735  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
736  }
737  if constexpr(kStoreLSE)
738  {
739  kargs.lse_ptr = lse_ptr;
740  kargs.nhead_stride_lse = nhead_stride_lse;
741  }
742  if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
743  {
744  kargs.q_descale_ptr = q_descale_ptr;
745  kargs.k_descale_ptr = k_descale_ptr;
746  kargs.v_descale_ptr = v_descale_ptr;
747  }
748  if constexpr(kHasDropout)
749  {
750  if(drop_seed_offset.index() == 0) // seed & offset come from host
751  {
752  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
753  kargs.init_dropout(p_drop, seed, offset);
754  }
755  else // seed & offset come from device
756  {
757  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
758  kargs.init_dropout(p_drop,
759  reinterpret_cast<const uint64_t*>(seed_ptr),
760  reinterpret_cast<const uint64_t*>(offset_ptr));
761  }
762 
763  kargs.rand_val_ptr = rand_val_ptr;
764  kargs.stride_randval = stride_randval;
765  kargs.nhead_stride_randval = nhead_stride_randval;
766  kargs.is_store_randval = s_randval;
767  }
768  if constexpr(kHasLogitsSoftCap)
769  {
770  kargs.init_logits_soft_cap(logits_soft_cap);
771  }
772  if constexpr(kSkipMinSeqlenQ)
773  {
774  kargs.min_seqlen_q = min_seqlen_q;
775  }
776 
777  kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
778  kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
779  return kargs;
780  }
781 
782  // std::variant<> can't be constructed from a braced initializer list; overload for backward compatibility
783  template <bool Cond = kIsGroupMode>
784  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
785  MakeKargs(const void* q_ptr,
786  const void* k_ptr,
787  const void* v_ptr,
788  const void* bias_ptr,
789  const void* q_descale_ptr,
790  const void* k_descale_ptr,
791  const void* v_descale_ptr,
792  void* rand_val_ptr,
793  void* lse_ptr,
794  void* o_ptr,
795  const void* seqstart_q_ptr,
796  const void* seqstart_k_ptr,
797  const void* seqlen_q_ptr,
798  const void* seqlen_k_ptr,
799  ck_tile::index_t hdim_q,
800  ck_tile::index_t hdim_v,
801  ck_tile::index_t num_head_q,
802  ck_tile::index_t nhead_ratio_qk,
803  float scale_s,
804  float logits_soft_cap,
805  ck_tile::index_t stride_q,
806  ck_tile::index_t stride_k,
807  ck_tile::index_t stride_v,
808  ck_tile::index_t stride_bias,
809  ck_tile::index_t stride_randval,
810  ck_tile::index_t stride_o,
811  ck_tile::index_t nhead_stride_q,
812  ck_tile::index_t nhead_stride_k,
813  ck_tile::index_t nhead_stride_v,
814  ck_tile::index_t nhead_stride_bias,
815  ck_tile::index_t nhead_stride_randval,
816  ck_tile::index_t nhead_stride_lse,
817  ck_tile::index_t nhead_stride_o,
818  ck_tile::index_t window_size_left,
819  ck_tile::index_t window_size_right,
820  ck_tile::index_t mask_type,
821  ck_tile::index_t min_seqlen_q,
822  float p_drop,
823  bool s_randval,
824  const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
825  const void* cu_seqlen_q_ptr = nullptr,
826  const void* cu_seqlen_k_ptr = nullptr)
827  {
828  return MakeKargsImpl(
829  q_ptr,
830  k_ptr,
831  v_ptr,
832  bias_ptr,
833  q_descale_ptr,
834  k_descale_ptr,
835  v_descale_ptr,
836  rand_val_ptr,
837  lse_ptr,
838  o_ptr,
839  seqstart_q_ptr,
840  seqstart_k_ptr,
841  seqlen_q_ptr,
842  seqlen_k_ptr,
843  hdim_q,
844  hdim_v,
845  num_head_q,
846  nhead_ratio_qk,
847  scale_s,
848  logits_soft_cap,
849  stride_q,
850  stride_k,
851  stride_v,
852  stride_bias,
853  stride_randval,
854  stride_o,
855  nhead_stride_q,
856  nhead_stride_k,
857  nhead_stride_v,
858  nhead_stride_bias,
859  nhead_stride_randval,
860  nhead_stride_lse,
861  nhead_stride_o,
862  window_size_left,
863  window_size_right,
864  mask_type,
865  min_seqlen_q,
866  p_drop,
867  s_randval,
868  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
869  cu_seqlen_q_ptr,
870  cu_seqlen_k_ptr);
871  }
872 
873  // std::variant<> can't be constructed from a braced initializer list; overload for backward compatibility
874  template <bool Cond = kIsGroupMode>
875  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
876  MakeKargs(const void* q_ptr,
877  const void* k_ptr,
878  const void* v_ptr,
879  const void* bias_ptr,
880  const void* q_descale_ptr,
881  const void* k_descale_ptr,
882  const void* v_descale_ptr,
883  void* rand_val_ptr,
884  void* lse_ptr,
885  void* o_ptr,
886  const void* seqstart_q_ptr,
887  const void* seqstart_k_ptr,
888  const void* seqlen_q_ptr,
889  const void* seqlen_k_ptr,
890  ck_tile::index_t hdim_q,
891  ck_tile::index_t hdim_v,
892  ck_tile::index_t num_head_q,
893  ck_tile::index_t nhead_ratio_qk,
894  float scale_s,
895  float logits_soft_cap,
896  ck_tile::index_t stride_q,
897  ck_tile::index_t stride_k,
898  ck_tile::index_t stride_v,
899  ck_tile::index_t stride_bias,
900  ck_tile::index_t stride_randval,
901  ck_tile::index_t stride_o,
902  ck_tile::index_t nhead_stride_q,
903  ck_tile::index_t nhead_stride_k,
904  ck_tile::index_t nhead_stride_v,
905  ck_tile::index_t nhead_stride_bias,
906  ck_tile::index_t nhead_stride_randval,
907  ck_tile::index_t nhead_stride_lse,
908  ck_tile::index_t nhead_stride_o,
909  ck_tile::index_t window_size_left,
910  ck_tile::index_t window_size_right,
911  ck_tile::index_t mask_type,
912  ck_tile::index_t min_seqlen_q,
913  float p_drop,
914  bool s_randval,
915  const std::tuple<const void*, const void*>& drop_seed_offset,
916  const void* cu_seqlen_q_ptr = nullptr,
917  const void* cu_seqlen_k_ptr = nullptr)
918  {
919  return MakeKargsImpl(
920  q_ptr,
921  k_ptr,
922  v_ptr,
923  bias_ptr,
924  q_descale_ptr,
925  k_descale_ptr,
926  v_descale_ptr,
927  rand_val_ptr,
928  lse_ptr,
929  o_ptr,
930  seqstart_q_ptr,
931  seqstart_k_ptr,
932  seqlen_q_ptr,
933  seqlen_k_ptr,
934  hdim_q,
935  hdim_v,
936  num_head_q,
937  nhead_ratio_qk,
938  scale_s,
939  logits_soft_cap,
940  stride_q,
941  stride_k,
942  stride_v,
943  stride_bias,
944  stride_randval,
945  stride_o,
946  nhead_stride_q,
947  nhead_stride_k,
948  nhead_stride_v,
949  nhead_stride_bias,
950  nhead_stride_randval,
951  nhead_stride_lse,
952  nhead_stride_o,
953  window_size_left,
954  window_size_right,
955  mask_type,
956  min_seqlen_q,
957  p_drop,
958  s_randval,
959  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
960  cu_seqlen_q_ptr,
961  cu_seqlen_k_ptr);
962  }
963 
964  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
965  ck_tile::index_t nhead_,
966  ck_tile::index_t seqlen_q_,
967  ck_tile::index_t hdim_v_,
968  bool has_padded_seqlen_k = false)
969  {
970  // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
971  if(has_padded_seqlen_k)
972  {
973  // TODO: this may need tuning
974  return dim3(nhead_,
975  batch_size_,
976  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
977  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
978  }
979  else
980  {
981  // TODO: this may need tuning
982  return dim3(nhead_,
983  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
984  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
985  batch_size_);
986  }
987  }
988 
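// Illustrative usage (not part of the original header), assuming a tile shape of
// kM0 = 128 and kN1 = 128 (both pipeline-specific), where Kernel stands for this
// kernel type:
//
//   // batch = 4, nhead = 16, seqlen_q = 1024, hdim_v = 128, no padded seqlen_k
//   // -> dim3(16, ceil(1024/128) * ceil(128/128), 4) == dim3(16, 8, 4)
//   const dim3 grid = Kernel::GridSize(4, 16, 1024, 128, /*has_padded_seqlen_k=*/false);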
989  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
990  {
991  bool has_padded_seqlen_k = false;
992 
993  if constexpr(kIsGroupMode)
994  has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
995 
996  if(has_padded_seqlen_k)
997  {
998  // const index_t num_tile_m0 = seqlen_q / kM0;
999  const index_t num_tile_n1 =
1000  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1001 
1002  const index_t i_block = blockIdx.z;
1003  const index_t i_nhead = blockIdx.x;
1004  const index_t i_batch = blockIdx.y;
1005 
1006  const auto f = [](index_t dividend, index_t divisor) {
1007  index_t quotient = dividend / divisor;
1008  index_t modulus = dividend - quotient * divisor;
1009  return ck_tile::make_tuple(quotient, modulus);
1010  };
1011 
1012  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1013 
1014  if constexpr(kHasMask)
1015  {
1016  // assume that num_tile_n1 is always 1
1017  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1018  }
1019  else
1020  {
1021  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1022  }
1023  }
1024  else
1025  {
1026  // const index_t num_tile_m0 = seqlen_q / kM0;
1027  const index_t num_tile_n1 =
1028  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1029 
1030  const index_t i_block = blockIdx.y; // blockIdx.x
1031  const index_t i_nhead = blockIdx.x; // blockIdx.y
1032  const index_t i_batch = blockIdx.z;
1033 
1034  const auto f = [](index_t dividend, index_t divisor) {
1035  index_t quotient = dividend / divisor;
1036  index_t modulus = dividend - quotient * divisor;
1037  return ck_tile::make_tuple(quotient, modulus);
1038  };
1039 
1040  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1041 
1042  if constexpr(kHasMask)
1043  {
1044  // assume that num_tile_n1 is always 1
1045  return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1046  }
1047  else
1048  {
1049  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1050  }
1051  }
1052  }
1053 
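// Illustrative note (not part of the original header): the lambda above splits the
// linear block index into (m-tile, n-tile) by quotient/remainder, e.g. for
// num_tile_n1 = 2:
//
//   // i_block : 0 1 2 3 4 5
//   // i_tile_m: 0 0 1 1 2 2
//   // i_tile_n: 0 1 0 1 0 1
//
// With masking enabled, m-tiles are issued in reverse (gridDim - 1 - i_tile_m) so
// the blocks with the most masked-out (cheapest) rows launch last, which helps
// balance work across compute units.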
1054  CK_TILE_HOST static dim3 BlockSize()
1055  {
1056  if(is_wave32())
1057  {
1058  return dim3(kBlockSize / 2);
1059  }
1060  else
1061  {
1062  return dim3(kBlockSize);
1063  }
1064  }
1065 
1066  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
1067  {
1068  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
1069  }
1070 
1071  CK_TILE_DEVICE void operator()(Kargs kargs) const
1072  {
1073  if constexpr(kIsAvailable)
1074  run_(std::move(kargs));
1075  }
1076 
1077  CK_TILE_DEVICE void run_(Kargs kargs) const
1078  {
1079  if constexpr(kPipelineName != "qr_async_trload")
1080  {
1081  // allocate LDS
1082  __shared__ char smem_ptr[GetSmemSize()];
1083 
1084  // divide problem
1085  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1086 
1087  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
1088  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
1089 
1090  long_index_t batch_offset_q = 0;
1091  long_index_t batch_offset_k = 0;
1092  long_index_t batch_offset_v = 0;
1093  long_index_t batch_offset_bias = 0;
1094  long_index_t batch_offset_randval = 0;
1095  long_index_t batch_offset_lse = 0;
1096  long_index_t batch_offset_o = 0;
1097 
1098  if constexpr(kIsGroupMode)
1099  {
1100  // Use seqstart_q_ptr and seqstart_k_ptr for physical starts
1101  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1102  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1103 
1104  // DRAM base offsets use physical starts
1105  batch_offset_q = query_start * kargs.stride_q;
1106  batch_offset_k = key_start * kargs.stride_k;
1107  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1108  {
1109  batch_offset_v = key_start * kargs.stride_v;
1110  }
1111  else
1112  {
1113  batch_offset_v = key_start;
1114  }
1115  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1116  {
1117  batch_offset_bias = query_start * kargs.stride_bias;
1118  }
1119  if constexpr(kStoreLSE)
1120  {
1121  // LSE follows the physical layout to stay consistent with other tensors
1122  batch_offset_lse = query_start;
1123  }
1124  if constexpr(kHasDropout)
1125  {
1126  batch_offset_randval = query_start * kargs.stride_randval;
1127  }
1128  batch_offset_o = query_start * kargs.stride_o;
1129 
1130  // real logical lengths (exclude PAD)
1131  // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
1132  if(kargs.seqlen_q_ptr != nullptr)
1133  {
1134  kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
1135  }
1136  else if(kargs.cu_seqlen_q_ptr != nullptr)
1137  {
1138  kargs.seqlen_q =
1139  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1140  }
1141  else
1142  {
1143  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1144  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1145  }
1146 
1147  if constexpr(kSkipMinSeqlenQ)
1148  {
1149  if(kargs.seqlen_q <= kargs.min_seqlen_q)
1150  {
1151  return;
1152  }
1153  }
1154 
1155  // terminate unnecessary blocks earlier
1156  if(kargs.seqlen_q <= i_m0)
1157  {
1158  return;
1159  }
1160 
1161  if(kargs.seqlen_k_ptr != nullptr)
1162  {
1163  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1164  }
1165  else if(kargs.cu_seqlen_k_ptr != nullptr)
1166  {
1167  kargs.seqlen_k =
1168  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1169  }
1170  else
1171  {
1172  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1173  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1174  }
1175  }
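// Illustrative note (not part of the original header): with cumulative arrays the
// per-batch logical length is the difference of adjacent entries, exactly as the
// branches above compute it. E.g. for cu_seqlen_q = {0, 128, 200, 512}, batch 1 has
//
//   int32_t seqlen_q_b1 = cu_seqlen_q_ptr[2] - cu_seqlen_q_ptr[1]; // 200 - 128 = 72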
1176  else
1177  {
1178  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1179  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1180  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1181  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1182  {
1183  batch_offset_bias =
1184  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1185  }
1186  if constexpr(kStoreLSE)
1187  {
1188  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1189  }
1190  if constexpr(kHasDropout)
1191  {
1192  batch_offset_randval =
1193  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
1194  }
1195  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1196 
1197  // If cumulative seqlen pointers are provided, override per-batch effective lengths
1198  if(kargs.cu_seqlen_q_ptr != nullptr)
1199  {
1200  kargs.seqlen_q =
1201  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1202  }
1203  if(kargs.cu_seqlen_k_ptr != nullptr)
1204  {
1205  kargs.seqlen_k =
1206  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1207  }
1208  }
1209 
1210  // for simplicity, we apply the batch stride by just offsetting the pointer
1211  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1212  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1213  batch_offset_q;
1214  const KDataType* k_ptr =
1215  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1216  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1217  batch_offset_k;
1218  const VDataType* v_ptr =
1219  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1220  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1221  batch_offset_v;
1222  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1223  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1224  batch_offset_o;
1225 
1226  // Q/K/V DRAM and DRAM window
1227  const auto q_dram = [&]() {
1228  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1229  q_ptr,
1230  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1231  make_tuple(kargs.stride_q, 1),
1233  number<1>{});
1234  if constexpr(FmhaPipeline::kQLoadOnce)
1235  {
1236  return pad_tensor_view(q_dram_naive,
1240  }
1241  else
1242  {
1243  return pad_tensor_view(
1244  q_dram_naive,
1247  }
1248  }();
1249  const auto k_dram = [&]() {
1250  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1251  k_ptr,
1252  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1253  make_tuple(kargs.stride_k, 1),
1255  number<1>{});
1256 
1257  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1258  return pad_tensor_view(
1259  k_dram_naive,
1262  }();
1263  const auto v_dram = [&]() {
1264  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1265  {
1266  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1267  v_ptr,
1268  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1269  make_tuple(kargs.stride_v, 1),
1271  number<1>{});
1272 
1273  const auto v_dram_transposed = transform_tensor_view(
1274  v_dram_naive,
1276  make_pass_through_transform(kargs.seqlen_k)),
1279 
1280  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1281  return pad_tensor_view(
1282  v_dram_transposed,
1285  }
1286  else
1287  {
1288  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1289  v_ptr,
1290  make_tuple(kargs.hdim_v, kargs.seqlen_k),
1291  make_tuple(kargs.stride_v, 1),
1293  number<1>{});
1294 
1295  constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
1296  return pad_tensor_view(
1297  v_dram_naive,
1300  }
1301  }();
1302 
1303  auto q_dram_window = make_tile_window(
1304  q_dram,
1305  [&]() {
1306  if constexpr(FmhaPipeline::kQLoadOnce)
1309  else
1311  }(),
1312  {i_m0, 0});
1313 
1314  auto k_dram_window = make_tile_window(
1315  k_dram,
1317  {0, 0});
1318 
1319  auto v_dram_window = make_tile_window(
1320  v_dram,
1322  {i_n1, 0});
1325  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1326  constexpr auto bias_dram_window_lengths =
1329  {
1330  const BiasDataType* bias_ptr =
1331  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1332  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1333  batch_offset_bias;
1334 
1335  const auto bias_dram = [&]() {
1336  const auto bias_dram_naive =
1337  make_naive_tensor_view<address_space_enum::global>(
1338  bias_ptr,
1339  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1340  make_tuple(kargs.stride_bias, 1),
1342  number<1>{});
1343 
1344  return pad_tensor_view(bias_dram_naive,
1345  bias_dram_window_lengths,
1347  }();
1348 
1349  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1350  }
1351  else
1352  {
1353  return make_null_tile_window(bias_dram_window_lengths);
1354  }
1355  }();
1356 
1357  // lse
1358  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1359  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1360  if constexpr(kStoreLSE)
1361  {
1362  LSEDataType* lse_ptr =
1363  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1364  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
1365  batch_offset_lse;
1366 
1367  const auto lse_dram = [&]() {
1368  const auto lse_dram_naive =
1369  make_naive_tensor_view<address_space_enum::global>(
1370  lse_ptr,
1371  make_tuple(kargs.seqlen_q),
1372  make_tuple(1),
1373  number<1>{},
1374  number<1>{});
1375 
1376  return pad_tensor_view(
1377  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1378  }();
1379 
1380  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1381  }
1382  else
1383  {
1384  return make_null_tile_window(lse_dram_window_lengths);
1385  }
1386  }();
1387 
1388  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1389  if constexpr(kHasDropout)
1390  {
1391  return BlockDropout{i_batch_,
1392  i_nhead_,
1393  kargs.num_head_q,
1394  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1395  : *kargs.drop_seed.ptr,
1396  kargs.is_drop_seed_offset_from_host
1397  ? kargs.drop_offset.val
1398  : *kargs.drop_offset.ptr,
1399  kargs.rp_undrop,
1400  kargs.p_undrop_in_uint8_t,
1401  kargs.is_store_randval};
1402  }
1403  else
1404  {
1405  return NullBlockDropout{};
1406  };
1407  }();
1408 
1409  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1410  constexpr auto randval_dram_window_lengths =
1412  if constexpr(kHasDropout)
1413  {
1414  RandValOutputDataType* rand_val_ptr =
1415  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1416  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1417  batch_offset_randval;
1418 
1419  const auto randval_dram = [&]() {
1420  const auto randval_dram_naive =
1421  make_naive_tensor_view<address_space_enum::global>(
1422  rand_val_ptr,
1423  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1424  make_tuple(kargs.stride_randval, 1),
1426  number<1>{});
1427 
1428  return pad_tensor_view(randval_dram_naive,
1429  randval_dram_window_lengths,
1431  }();
1432 
1433  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1434  }
1435  else
1436  {
1437  return make_null_tile_window(randval_dram_window_lengths);
1438  }
1439  }();
1440 
1441  FmhaMask mask = [&]() {
1442  if constexpr(kHasMask)
1443  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1444  kargs.window_size_left,
1445  kargs.window_size_right,
1446  kargs.seqlen_q,
1447  kargs.seqlen_k,
1449  else
1450  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1451  }();
1452 
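// Illustrative note (not part of the original header): the (window_size_left,
// window_size_right) pair generalizes common mask shapes. By a widely used
// convention (assumed here, not confirmed by this listing) a negative value means
// an unbounded window:
//
//   // causal:               window_size_left = -1,  window_size_right = 0
//   // sliding window (256): window_size_left = 255, window_size_right = 0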
1453  // workaround: lambdas can't capture structured bindings before C++20, so rebind i_batch/i_nhead
1454  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1455  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1456  {
1457  // data loading, shared by entire wg
1458  // TODO: how to use s_read?
1459  SaccDataType slope =
1460  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1461  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1462 #if CK_TILE_FMHA_FWD_FAST_EXP2
1463  slope *= ck_tile::log2e_v<>;
1464 #endif
1465  if constexpr(kHasMask)
1466  {
1467  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1468  kargs.window_size_left,
1469  kargs.window_size_right,
1470  kargs.seqlen_q,
1471  kargs.seqlen_k,
1472  kargs.mask_type);
1473  }
1474  else
1475  {
1477  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1478  }
1479  }
1480  else
1481  {
1483  }
1484  }();
1485 
1486  AttentionVariant variant;
1487  const auto variant_params = [&] {
1488  if constexpr(kHasLogitsSoftCap)
1489  {
1491  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1492  }
1493  else
1494  {
1495  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1496  }
1497  }();
1498 
1499  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1500 
1501  auto o_acc_tile = [&]() {
1502  if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
1503  {
1504  // TODO - move global load of descale to pipeline
1505  float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
1506  float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
1507  float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
1508 
1509  float scale_s = kargs.scale_s * q_descale * k_descale;
1510  float scale_p =
1511  ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
1512  float scale_o = v_descale / scale_p;
1513 
1514  auto o_acc_element_func = [&]() {
1515  if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
1517  ck_tile::scales{scale_o});
1518  else
1519  return ck_tile::scales{scale_o};
1520  }();
1521  return FmhaPipeline{}(q_dram_window,
1522  identity{}, // q_element_func
1523  k_dram_window,
1524  identity{}, // k_element_func
1525  v_dram_window,
1526  identity{}, // v_element_func
1527  bias_dram_window,
1528  identity{}, // bias_element_func
1529  randval_dram_window,
1530  lse_dram_window,
1531  identity{}, // lse_element_func
1532  identity{}, // s_acc_element_func
1533  scales{scale_p}, // p_compute_element_func
1534  o_acc_element_func, // o_acc_element_func
1535  mask,
1536  position_encoding,
1537  scale_s,
1538  variant,
1539  variant_params,
1540  block_indices,
1541  smem_ptr,
1542  dropout);
1543  }
1544  else
1545  {
1546  return FmhaPipeline{}(q_dram_window,
1547  k_dram_window,
1548  v_dram_window,
1549  bias_dram_window,
1550  randval_dram_window,
1551  lse_dram_window,
1552  mask,
1553  position_encoding,
1554  kargs.scale_s,
1555  variant,
1556  variant_params,
1557  block_indices,
1558  smem_ptr,
1559  dropout);
1560  }
1561  }();
1562 
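// Illustrative note (not part of the original header): for per-tensor quantized
// inputs the branch above folds the dequantization factors into existing
// multipliers instead of applying them per element:
//
//   float scale_s = kargs.scale_s * q_descale * k_descale; // folded into the logits
//   float scale_p = type_convert<float>(numeric<PDataType>::max()); // requantize P
//   float scale_o = v_descale / scale_p; // undo both descales on the output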
1563  // O DRAM and O DRAM window
1564  auto o_dram = [&]() {
1565  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1566  o_ptr,
1567  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1568  make_tuple(kargs.stride_o, 1),
1570  number<1>{});
1571 
1572  return pad_tensor_view(
1573  o_dram_naive,
1576  }();
1577 
1578  auto o_dram_window = make_tile_window(
1579  o_dram,
1581  {i_m0, i_n1});
1582 
1583  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1584  }
1585  else
1586  {
1587  // TODO: Refine the logic here.
1588  // In the decode case:
1589  // 1. KV data is not expected to be reused across thread groups, so bypass the cache
1590  // 2. limit LDS usage, since we want higher occupancy
1591  // In the prefill case:
1592  // 1. KV data is expected to be reused across thread groups, so use the cache
1593  // 2. use more LDS, since we want better memory-latency hiding
1594  // If SplitKV is off, Q data is not expected to be reused across thread groups, so
1595  // bypass the cache
1596  constexpr bool PrefillCase = FmhaPipeline::kM0 > 64;
1597  // divide problem
1598  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1599 
1600  const index_t i_m0 = i_tile_m * FmhaPipeline::kM0;
1601  const index_t i_n1 = i_tile_n * FmhaPipeline::kN1;
1602 
1603  long_index_t batch_offset_q = 0;
1604  long_index_t batch_offset_k = 0; // unused for paged-kvcache
1605  long_index_t batch_offset_v = 0; // unused for paged-kvcache
1606  long_index_t batch_offset_bias = 0;
1607  long_index_t batch_offset_lse = 0;
1608  long_index_t batch_offset_o = 0;
1609  // index_t kv_l2p_offset =
1610  // 0; // logical-to-physical offset of seqlen_k coordinate. only used for
1611  // paged-kvcache
1612 
1613  if constexpr(kIsGroupMode)
1614  {
1615  // get starting offset for each batch - use seqstart_q_ptr/seqstart_k_ptr for
1616  // physical starts
1617  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1618  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1619 
1620  batch_offset_q = query_start * kargs.stride_q;
1621  batch_offset_k = key_start * kargs.stride_k;
1622  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1623  {
1624  batch_offset_v = key_start * kargs.stride_v;
1625  }
1626  else
1627  {
1628  // col-major V: offset along seqlen dimension is scalar index
1629  batch_offset_v = key_start;
1630  }
1631  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1632  {
1633  batch_offset_bias = query_start * kargs.stride_bias;
1634  }
1635 
1636  // LSE layout is [nhead, total_seqlen] following the physical layout for Q/O
1637  batch_offset_lse = query_start;
1638  batch_offset_o = query_start * kargs.stride_o;
1639 
1640  // get real # queries & # keys under group mode
1641  if(kargs.seqlen_q_ptr != nullptr)
1642  {
1643  kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
1644  }
1645  else if(kargs.cu_seqlen_q_ptr != nullptr)
1646  {
1647  kargs.seqlen_q =
1648  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1649  }
1650  else
1651  {
1652  kargs.seqlen_q =
1653  kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
1654  }
1655 
1656  // the number of required blocks differs between groups; terminate unnecessary
1657  // blocks earlier
1658  if(kargs.seqlen_q <= i_m0)
1659  {
1660  return;
1661  }
1662 
1663  if(kargs.seqlen_k_ptr != nullptr)
1664  {
1665  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1666  }
1667  else if(kargs.cu_seqlen_k_ptr != nullptr)
1668  {
1669  kargs.seqlen_k =
1670  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1671  }
1672  else
1673  {
1674  kargs.seqlen_k =
1675  kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
1676  }
1677  }
1678  else
1679  {
1680  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1681  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1682  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1683  if constexpr(kStoreLSE)
1684  {
1685  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1686  }
1687  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1688 
1689  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1690  {
1691  batch_offset_bias =
1692  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1693  }
1694 
1695  // If cumulative seqlen pointers are provided, override per-batch effective lengths
1696  if(kargs.cu_seqlen_q_ptr != nullptr)
1697  {
1698  kargs.seqlen_q =
1699  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1700  }
1701  if(kargs.cu_seqlen_k_ptr != nullptr)
1702  {
1703  kargs.seqlen_k =
1704  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1705  }
1706  }
1707 
1708  // for simplicity, we apply the batch stride by just offsetting the pointer
1709  const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk;
1710 
1711  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1712  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1713  batch_offset_q;
1714  const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1715  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
1716  batch_offset_k;
1717  const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1718  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
1719  batch_offset_v;
1720 
1721  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1722  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1723  batch_offset_o;
1724 
1725  // Q/K/V DRAM and DRAM window
1726  const auto q_dram = [&] {
1727  const auto q_dram_naive = [&] {
1728  {
1729  return make_naive_tensor_view<address_space_enum::global,
1730  memory_operation_enum::set,
1732  q_ptr,
1733  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1734  make_tuple(kargs.stride_q, 1),
1736  number<1>{});
1737  }
1738  }();
1739 
1740  if constexpr(FmhaPipeline::kQLoadOnce)
1741  {
1742  const auto seqlen_q = kargs.seqlen_q;
1743  const auto q_dram_pad = pad_tensor_view(
1744  q_dram_naive,
1747 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1748  constexpr index_t LDSLayerSize = 256 / sizeof(QDataType);
1749  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1750 
1751  if constexpr(XorLengthFold > 1)
1752  {
1753  const auto q_dram_unmerged = transform_tensor_view(
1754  q_dram_pad,
1755  make_tuple(
1757  make_tuple(seqlen_q / XorLengthFold, XorLengthFold)),
1761 
1762  const auto q_dram_merged = transform_tensor_view(
1763  q_dram_unmerged,
1764  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1766  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1769 
1770  const auto q_dram_unmerged_xor = transform_tensor_view(
1771  q_dram_merged,
1772  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1778 
1779  const auto q_dram_permuted = transform_tensor_view(
1780  q_dram_unmerged_xor,
1781  make_tuple(
1783  make_tuple(seqlen_q / XorLengthFold,
1788 
1789  const auto q_dram_tmp = transform_tensor_view(
1790  q_dram_permuted,
1791  make_tuple(
1792  make_pass_through_transform(seqlen_q / XorLengthFold),
1795  number<FmhaPipeline::kQKHeaddim /
1796  FmhaPipeline::kAlignmentQ>{})),
1800 
1801  return transform_tensor_view(
1802  q_dram_tmp,
1803  make_tuple(
1805  make_tuple(seqlen_q / XorLengthFold, number<XorLengthFold>{})),
1811  }
1812  else
1813 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1814  {
1815  const auto q_dram_unmerged = transform_tensor_view(
1816  q_dram_pad,
1817  make_tuple(
1818  make_pass_through_transform(seqlen_q),
1824 
1825  const auto q_dram_permuted = transform_tensor_view(
1826  q_dram_unmerged,
1827  make_tuple(
1828  make_xor_transform(make_tuple(seqlen_q,
1829  number<FmhaPipeline::kQKHeaddim /
1830  FmhaPipeline::kAlignmentQ>{})),
1834 
1835  return transform_tensor_view(
1836  q_dram_permuted,
1837  make_tuple(
1838  make_pass_through_transform(seqlen_q),
1844  }
1845  }
1846  else
1847  {
1848  return pad_tensor_view(
1849  q_dram_naive,
1852  }
1853  }();
1854 
1855  const auto make_k_dram = [&](const KDataType* data, index_t height) {
1856  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1857  data, // will update this pointer if using paged-kvcache
1858  make_tuple(height, kargs.hdim_q),
1859  make_tuple(kargs.stride_k, 1),
1861  number<1>{});
1862 
1863  const auto k_dram_pad = pad_tensor_view(
1864  k_dram_naive,
1867 
1868  constexpr auto kDramTileK =
1869  FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0;
1870 
1871 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1872  constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
1873  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1874 
1875  if constexpr(XorLengthFold > 1)
1876  {
1877  const auto k_dram_unmerged = transform_tensor_view(
1878  k_dram_pad,
1880  make_tuple(height / XorLengthFold, XorLengthFold)),
1884 
1885  const auto k_dram_merged = transform_tensor_view(
1886  k_dram_unmerged,
1887  make_tuple(make_pass_through_transform(height / XorLengthFold),
1889  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1892 
1893  const auto k_dram_unmerged_xor = transform_tensor_view(
1894  k_dram_merged,
1895  make_tuple(make_pass_through_transform(height / XorLengthFold),
1901 
1902  const auto k_dram_permuted = transform_tensor_view(
1903  k_dram_unmerged_xor,
1904  make_tuple(
1906  make_tuple(height / XorLengthFold,
1911 
1912  const auto k_dram_tmp = transform_tensor_view(
1913  k_dram_permuted,
1914  make_tuple(
1915  make_pass_through_transform(height / XorLengthFold),
1918  number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
1922 
1923  return transform_tensor_view(
1924  k_dram_tmp,
1925  make_tuple(
1927  make_tuple(height / XorLengthFold, number<XorLengthFold>{})),
1933  }
1934  else
1935 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1936  {
1937  const auto k_dram_unmerged = transform_tensor_view(
1938  k_dram_pad,
1941  make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
1942  FmhaPipeline::kAlignmentK>{},
1943  number<kDramTileK / FmhaPipeline::kAlignmentK>{},
1947 
1948  const auto k_dram_permuted = transform_tensor_view(
1949  k_dram_unmerged,
1950  make_tuple(
1954  number<FmhaPipeline::kQKHeaddim / kDramTileK /
1955  FmhaPipeline::kAlignmentK>{}),
1959 
1960  return transform_tensor_view(
1961  k_dram_permuted,
1964  make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
1965  FmhaPipeline::kAlignmentK>{},
1966  number<kDramTileK / FmhaPipeline::kAlignmentK>{},
1970  }
1971  };
1972  const auto k_dram = [&]() {
1973  {
1974  return make_k_dram(k_ptr, kargs.seqlen_k);
1975  }
1976  }();
1977 
1978  const auto make_v_dram = [&](const VDataType* data, index_t length) {
1979  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1980  data, // will update this pointer if using paged-kvcache
1981  make_tuple(length, kargs.hdim_v),
1982  make_tuple(kargs.stride_v, 1),
1984  number<1>{});
1985 
1986  // TODO: Add kVHeadDim
1987  constexpr index_t XorGroupSize =
1988  FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
1989 
1990  const auto v_dram_pad = pad_tensor_view(
1991  v_dram_naive,
1994 
1995 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1996  constexpr index_t LDSLayerSize = 256 / sizeof(VDataType);
1997  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1998 
1999  if constexpr(XorLengthFold > 1)
2000  {
2001  const auto v_dram_unmerged = transform_tensor_view(
2002  v_dram_pad,
2004  make_tuple(length / XorLengthFold, XorLengthFold)),
2008 
2009  const auto v_dram_merged = transform_tensor_view(
2010  v_dram_unmerged,
2011  make_tuple(make_pass_through_transform(length / XorLengthFold),
2013  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
2016 
2017  const auto v_dram_unmerged_xor = transform_tensor_view(
2018  v_dram_merged,
2019  make_tuple(
2020  make_pass_through_transform(length / XorLengthFold),
2022  number<XorGroupSize>{}))),
2025 
2026  const auto v_dram_permuted = transform_tensor_view(
2027  v_dram_unmerged_xor,
2028  make_tuple(
2029  make_xor_transform(make_tuple(length / XorLengthFold,
2034 
2035  const auto v_dram_tmp = transform_tensor_view(
2036  v_dram_permuted,
2037  make_tuple(make_pass_through_transform(length / XorLengthFold),
2040  number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
2044 
2045  return transform_tensor_view(
2046  v_dram_tmp,
2048  make_tuple(length / XorLengthFold, number<XorLengthFold>{})),
2051  number<XorGroupSize>{}))),
2054  }
2055  else
2056 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2057  {
2058  const auto v_dram_unmerged = transform_tensor_view(
2059  v_dram_pad,
2063  number<XorGroupSize>{}))),
2066 
2067  const auto v_dram_permuted = transform_tensor_view(
2068  v_dram_unmerged,
2074 
2075  return transform_tensor_view(
2076  v_dram_permuted,
2080  number<XorGroupSize>{}))),
2083  }
2084  };
2085 
2086  const auto v_dram = [&]() {
2087  {
2088  return make_v_dram(v_ptr, kargs.seqlen_k);
2089  }
2090  }();
2091 
2092  auto q_dram_window = make_tile_window(
2093  q_dram,
2094  [&]() {
2095                  if constexpr(FmhaPipeline::kQLoadOnce)
2096                      return make_tuple(number<FmhaPipeline::kM0>{},
2097                                        number<FmhaPipeline::kQKHeaddim>{});
2098                  else
2099                      return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
2100              }(),
2101  {i_m0, 0});
2102 
2103  auto k_dram_window = make_tile_window(
2104              k_dram,
2105              make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
2106              {0, 0});
2107 
2108  auto v_dram_window = make_tile_window(
2109              v_dram,
2110              make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
2111              {0, 0});
2112 
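          // Each window pairs a padded view with a tile footprint and a start origin: Q
          // starts at {i_m0, 0} (this block's query rows), while K and V start at {0, 0}
          // and are stepped along seqlen_k by the pipeline. A hypothetical sketch of that
          // stepping, outside this kernel:
          //
          //   for(index_t i = 0; i < num_k_tiles; ++i)
          //   {
          //       auto k_tile = load_tile(k_dram_window);                  // one kN0 x kK0 tile
          //       move_tile_window(k_dram_window, {FmhaPipeline::kN0, 0}); // next seqlen_k tile
          //   }
 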
2115  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
2116              constexpr auto bias_dram_window_lengths =
2117                  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
2118              if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
2119              {
2120  const BiasDataType* bias_ptr =
2121  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
2122  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
2123  batch_offset_bias;
2124 
2125  const auto bias_dram = [&]() {
2126  const auto bias_dram_naive =
2127  make_naive_tensor_view<address_space_enum::global>(
2128  bias_ptr,
2129  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
2130                              make_tuple(kargs.stride_bias, 1),
2131                              number<FmhaPipeline::kAlignmentBias>{},
2132                              number<1>{});
2133 
2134                      return pad_tensor_view(bias_dram_naive,
2135                                             bias_dram_window_lengths,
2136                                             sequence<kPadSeqLenQ, kPadSeqLenK>{});
2137                  }();
2138 
2139  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
2140  }
2141  else
2142  {
2143  return make_null_tile_window(bias_dram_window_lengths);
2144  }
2145  }();
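 
          // With ELEMENTWISE_BIAS, bias_ptr is already displaced by head and batch, so
          // element (i, j) of the tile reduces to plain row-major arithmetic; sketch only,
          // with s denoting this block's logits tile, shown schematically:
          //
          //   long_index_t off = static_cast<long_index_t>(i) * kargs.stride_bias + j;
          //   SaccDataType s_biased = s(i, j) + type_convert<SaccDataType>(bias_ptr[off]);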
2146 
2147          // LSE output window
2148  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
2149  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
2150  if constexpr(kStoreLSE)
2151  {
2152  LSEDataType* lse_ptr =
2153  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
2154  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
2155  batch_offset_lse;
2156 
2157  const auto lse_dram = [&] {
2158                  const auto lse_dram_naive = [&] {
2160                      return make_naive_tensor_view<address_space_enum::global>(
2161                          lse_ptr,
2162                          make_tuple(kargs.seqlen_q),
2163                          make_tuple(1),
2164                          number<1>{},
2165                          number<1>{});
2167                  }();
2168  return pad_tensor_view(
2169  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
2170  }();
2171 
2172  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
2173  }
2174  else
2175  {
2176  return make_null_tile_window(lse_dram_window_lengths);
2177  }
2178  }();
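 
          // When kStoreLSE is set, the pipeline emits one log-sum-exp value per query row,
          // lse[i] = m_i + log(l_i), where m_i and l_i are the row maximum and rescaled
          // exponential sum kept by the online softmax. A scalar model of the stored
          // quantity (illustrative only, not the pipeline's tiled implementation):
          //
          //   float m = -INFINITY, l = 0.f;
          //   for(index_t j = 0; j < seqlen_k; ++j)
          //   {
          //       const float m_new = max(m, s[j]);
          //       l = l * expf(m - m_new) + expf(s[j] - m_new); // rescale the running sum
          //       m = m_new;
          //   }
          //   const float lse = m + logf(l); // reused by backward / split-kv combine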
2179 
2180  FmhaMask mask = [&]() {
2181  if constexpr(kHasMask)
2182  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
2183  kargs.window_size_left,
2184  kargs.window_size_right,
2185  kargs.seqlen_q,
2186                      kargs.seqlen_k,
2187                      kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
2188              else
2189  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
2190  }();
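 
          // The left/right window pair bounds the keys each query row may attend to; with
          // the bottom-right diagonal offset d = seqlen_k - seqlen_q, a sketch of the
          // admissible column range for row i (ignoring the "window disabled" encodings):
          //
          //   const index_t lo = max(0, i + d - kargs.window_size_left);
          //   const index_t hi = min(kargs.seqlen_k, i + d + kargs.window_size_right + 1);
          //   // columns j in [lo, hi) keep their logits; the rest get -inf before softmax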
2191 
2192          // workaround: i_batch comes from a structured binding, which lambdas cannot capture before C++20
2193  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
2194  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
2195  {
2196              // load the slope once; the value is shared by the entire workgroup
2197  // TODO: how to use s_read?
2198  SaccDataType slope =
2199  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
2200  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
2201 #if CK_TILE_FMHA_FWD_FAST_EXP2
2202  slope *= ck_tile::log2e_v<>;
2203 #endif
2204  if constexpr(kHasMask)
2205  {
2206  return make_alibi_from_lr_mask<SaccDataType, true, 32>(
2207  slope,
2208  kargs.window_size_left,
2209  kargs.window_size_right,
2210  kargs.seqlen_q,
2211  kargs.seqlen_k,
2212  kargs.mask_type);
2213  }
2214  else
2215  {
2216                  return Alibi<SaccDataType, true>{
2217                      slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
2218  }
2219  }
2220  else
2221              {
2222                  return EmptyPositionEncoding<SaccDataType>{};
2223              }
2224  }();
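 
          // ALIBI replaces a learned bias with a per-head linear distance penalty on the
          // logits. Sketch of the value added at (i, j) in bottom-right mode, with
          // d = seqlen_k - seqlen_q (slope is pre-scaled by log2(e) above when
          // CK_TILE_FMHA_FWD_FAST_EXP2 is on, so the pipeline can use exp2 directly):
          //
          //   SaccDataType alibi_bias = slope * -abs(j - (i + d)); // grows with distance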
2225 
2226  auto o_acc_tile = [&]() {
2227  if constexpr(PrefillCase)
2228  {
2229                  // allocate double (ping-pong) LDS buffers for K and V
2230                  // add __restrict__ here to avoid aliasing between the buffers
2231  __shared__ char smem_ptrk0
2232  [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
2233  true>()];
2234  __shared__ char smem_ptrk1
2235  [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
2236  true>()];
2237  __shared__ char smem_ptrv0[FmhaPipeline::Policy::template GetSmemSizeV<
2238  typename FmhaPipeline::Problem>()];
2239  __shared__ char smem_ptrv1[FmhaPipeline::Policy::template GetSmemSizeV<
2240  typename FmhaPipeline::Problem>()];
2241 
2242  return FmhaPipeline{}(q_dram_window,
2243  k_dram_window,
2244  v_dram_window,
2245  bias_dram_window,
2246  lse_dram_window,
2247  mask,
2248  position_encoding,
2249  kargs.scale_s,
2250  smem_ptrk0,
2251  smem_ptrk1,
2252  smem_ptrv0,
2253  smem_ptrv1);
2254  }
2255  else
2256  {
2257  __shared__ char smem_ptr[GetSmemSize()];
2258  return FmhaPipeline{}(q_dram_window,
2259  k_dram_window,
2260  v_dram_window,
2261  bias_dram_window,
2262  lse_dram_window,
2263  mask,
2264  position_encoding,
2265  kargs.scale_s,
2266  smem_ptr);
2267  }
2268  }();
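 
          // The prefill path hands the pipeline two K and two V LDS buffers so the next
          // tile's global->LDS copy can overlap the current tile's MFMA work; a generic
          // ping-pong sketch of that schedule (hypothetical helpers, not this pipeline):
          //
          //   char* kbuf[2] = {smem_ptrk0, smem_ptrk1};
          //   prefetch_k_tile(kbuf[0], 0);                     // prime buffer 0
          //   for(index_t it = 0; it < num_tiles; ++it)
          //   {
          //       prefetch_k_tile(kbuf[(it + 1) & 1], it + 1); // async-fill the idle buffer
          //       gemm_softmax_on(kbuf[it & 1]);               // consume the ready buffer
          //   }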
2269 
2270          // O DRAM view and O DRAM window
2271  auto o_dram = [&] {
2272              const auto o_dram_naive = [&] {
2274                  return make_naive_tensor_view<address_space_enum::global>(
2275                      o_ptr,
2276                      make_tuple(kargs.seqlen_q, kargs.hdim_v),
2277                      make_tuple(kargs.stride_o, 1),
2278                      number<FmhaPipeline::kAlignmentO>{},
2279                      number<1>{});
2281              }();
2282 
2283              return pad_tensor_view(
2284                  o_dram_naive,
2285                  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
2286                  sequence<kPadSeqLenQ, kPadHeadDimV>{});
2287          }();
2288 
2289  auto o_dram_window = make_tile_window(
2290              o_dram,
2291              make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
2292              {i_m0, i_n1});
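 
          // pad_tensor_view only rounds the logical extents up to tile multiples; the real
          // lengths are kept, so out-of-range lanes of the epilogue store are predicated
          // off. Sketch of the effective guard for one element:
          //
          //   const bool writable = (!kPadSeqLenQ || row < kargs.seqlen_q) &&
          //                         (!kPadHeadDimV || col < kargs.hdim_v);
          //   if(writable) { o_ptr[row * kargs.stride_o + col] = type_convert<ODataType>(acc); }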
2293 
2294  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
2295  }
2296  }
2297 };
2298 
2299 } // namespace ck_tile