include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File

include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File#

Composable Kernel: include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File
fmha_fwd_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
9 
10 #include <string>
11 #include <type_traits>
12 #include <utility>
13 #include <variant>
14 
15 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
16 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
17 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
18 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
19 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
20 
21 namespace ck_tile {
22 
23 template <typename FmhaPipeline_, typename EpiloguePipeline_>
// NOTE(review): the class-head line (doxygen line 24) was lost in extraction --
// presumably `struct FmhaFwdKernel` given the file name; confirm against upstream.
25 {
// launch configuration constants taken from the pipeline policy
28  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
29  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
30  static_assert(kBlockPerCu > 0);
// raw user-requested occupancy (-1 means "unspecified", see GetName())
31  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
32 
42 
44 
// compile-time feature flags forwarded from the pipeline: batch-vs-group mode,
// which dimensions may need padding, bias kind, LSE output, dropout,
// fp8 static quantization and masking
45  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
46  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
47  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
48  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
49  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
50  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
51  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
52  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
53  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
55  static constexpr bool kHasMask = FmhaMask::IsMasking;
56 
// t2s: compile-time map from an element type to the short dtype tag used in GetName()
57  // clang-format off
58  template <typename T> struct t2s;
59  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
60  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
61  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
62  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
63  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
64  // clang-format on
65 
66  CK_TILE_HOST static std::string GetName()
67  {
// Compose the canonical kernel-instance name used by the dispatcher/codegen:
// head dim, dtype tag, batch/group mode, block tile shape, warp layouts,
// occupancy hint, pipeline name, V layout, padding flags and feature suffixes.
68  // sync with generate.py
69  // clang-format off
70  using bfs = typename FmhaPipeline::BlockFmhaShape;
71  using g0br = typename bfs::Gemm0BlockWarps;
72  using g1br = typename bfs::Gemm1BlockWarps;
73  using g0wt = typename bfs::Gemm0WarpTile;
74  using g1wt = typename bfs::Gemm1WarpTile;
75  #define _SS_ std::string
76  #define _TS_ std::to_string
// pn: encodes which dims are padded, e.g. "psskddv"; empty when nothing is padded
77  auto pn = [&] () {
78  std::string n;
79  if (kPadSeqLenQ) n += "s";
80  if (kPadSeqLenK) n += "sk";
81  if (kPadHeadDimQ) n += "d";
82  if (kPadHeadDimV) n += "dv";
83  return n.empty() ? n : std::string("p") + n; }();
84  return
85  _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
86  "_" + (kIsGroupMode ? "group" : "batch") + "_"
87  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
88  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
89  "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
90  "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
91  "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
92  "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
93  (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
94  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
// NOTE(review): doxygen line 95 (presumably the bias-kind suffix term of this
// concatenation) was lost in extraction -- confirm against upstream.
96  (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
97  #undef _SS_
98  #undef _TS_
99  // clang-format on
100  }
101 
102  template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
103  // arg
// NOTE(review): several struct-head lines in this region (doxygen lines 104, 111,
// 140, 147, 152, 159, 166, 172, 179, 182, 193, 228, 233-234, 251-252) and some
// member-declaration lines were lost in extraction. The braces below belong to
// the kernel-argument helper structs (empty kargs, common kargs, bias / alibi /
// mask / fp8 / LSE / dropout kargs, and the batch/group aggregates). Confirm
// exact names and missing members against the upstream header.
105  {
106  };
107 
108  // kargs use aggregate initializer, so no constructor will be provided
109  // use inheritance to minimize karg size
110  // user need to use MakeKargs() function to create kargs.
112  {
// base pointers of the four I/O tensors; re-typed at use-site via reinterpret_cast
113  const void* q_ptr;
114  const void* k_ptr;
115  const void* v_ptr;
116  void* o_ptr;
117 
122 
124  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
125  // if this param is larger than 1, indicate MQA/GQA case
127  float scale_s;
128 
133 
138  };
139 
141  {
142  const void* bias_ptr = nullptr;
145  };
146 
148  {
150  };
151 
153  {
154  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
155  const void* alibi_slope_ptr;
156  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
157  };
158 
160  {
161  // ck_tile::index_t window_size_left, window_size_right;
164  };
165 
167  {
// static fp8 quantization scales applied to P and O (kDoFp8StaticQuant path)
168  float scale_p;
169  float scale_o;
170  };
171 
173  {
174  void* lse_ptr = nullptr;
177  };
178 
180  {
181  template <typename T>
// value-or-pointer alternative for the dropout seed/offset: `val` when supplied
// by the host, `ptr` when it must be read from device memory
183  {
184  T val;
185  const T* ptr;
186  };
187 
191  };
192 
194  {
// host-supplied RNG state: store seed/offset by value
195  void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
196  {
197  float p_undrop = 1.0 - p_drop;
// NOTE(review): the assignment target on doxygen line 198 (presumably the packed
// uint8 keep-probability member) was lost in extraction.
199  uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
200  rp_undrop = 1.0 / p_undrop;
201 
202  this->drop_seed.val = seed;
203  this->drop_offset.val = offset;
204  this->is_drop_seed_offset_from_host = true;
205  }
206 
// device-supplied RNG state: store pointers, defer the read to the kernel
207  void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
208  {
209  float p_undrop = 1.0 - p_drop;
// NOTE(review): assignment target on doxygen line 210 lost in extraction
// (same member as in the overload above).
211  uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
212  rp_undrop = 1.0 / p_undrop;
213 
214  this->drop_seed.ptr = seed_ptr;
215  this->drop_offset.ptr = offset_ptr;
216  this->is_drop_seed_offset_from_host = false;
217  }
218 
219  float rp_undrop = 1; // reciprocal keep-probability used to rescale kept values
221  bool is_store_randval = false;
222  void* rand_val_ptr = nullptr;
223 
226  };
227 
229  {
231  };
232 
// batch-mode kargs: common kargs plus feature kargs selected by compile-time flags
// (an FmhaFwdEmptyKargs<I> base keeps size at zero when a feature is disabled)
235  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
236  FmhaFwdBatchModeBiasKargs,
237  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
238  FmhaFwdAlibiKargs,
239  FmhaFwdEmptyKargs<0>>>,
240  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
241  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
242  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
243  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>
244  {
249  };
250 
// group-mode kargs: per-sequence lengths come from the seqstart/seqlen pointers
253  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
254  FmhaFwdCommonBiasKargs,
255  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
256  FmhaFwdAlibiKargs,
257  FmhaFwdEmptyKargs<0>>>,
258  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
259  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
260  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
261  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>
262  {
263  const int32_t* seqstart_q_ptr;
264  const int32_t* seqstart_k_ptr;
265  const int32_t* seqlen_k_ptr;
266  };
267 
// the kernel-argument type actually used, picked by mode
268  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
269 
270  template <bool Cond = !kIsGroupMode>
271  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
272  MakeKargsImpl(const void* q_ptr,
273  const void* k_ptr,
274  const void* v_ptr,
275  const void* bias_ptr,
276  void* rand_val_ptr,
277  void* lse_ptr,
278  void* o_ptr,
279  ck_tile::index_t seqlen_q,
280  ck_tile::index_t seqlen_k,
281  ck_tile::index_t hdim_q,
282  ck_tile::index_t hdim_v,
283  ck_tile::index_t num_head_q,
284  ck_tile::index_t nhead_ratio_qk,
285  float scale_s,
286  float scale_p,
287  float scale_o,
288  ck_tile::index_t stride_q,
289  ck_tile::index_t stride_k,
290  ck_tile::index_t stride_v,
291  ck_tile::index_t stride_bias,
292  ck_tile::index_t stride_randval,
293  ck_tile::index_t stride_o,
294  ck_tile::index_t nhead_stride_q,
295  ck_tile::index_t nhead_stride_k,
296  ck_tile::index_t nhead_stride_v,
297  ck_tile::index_t nhead_stride_bias,
298  ck_tile::index_t nhead_stride_randval,
299  ck_tile::index_t nhead_stride_lse,
300  ck_tile::index_t nhead_stride_o,
301  ck_tile::index_t batch_stride_q,
302  ck_tile::index_t batch_stride_k,
303  ck_tile::index_t batch_stride_v,
304  ck_tile::index_t batch_stride_bias,
305  ck_tile::index_t batch_stride_randval,
306  ck_tile::index_t batch_stride_lse,
307  ck_tile::index_t batch_stride_o,
308  ck_tile::index_t window_size_left,
309  ck_tile::index_t window_size_right,
310  ck_tile::index_t mask_type,
311  float p_drop,
312  bool s_randval,
313  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
314  drop_seed_offset)
315  {
// Batch-mode kernel-argument factory. The common fields are aggregate-initialized
// below; feature-specific fields are filled in only when the matching
// compile-time flag is enabled (their placeholders stay empty otherwise).
316  Kargs kargs{{q_ptr,
317  k_ptr,
318  v_ptr,
319  o_ptr,
320  seqlen_q,
321  seqlen_k,
322  hdim_q,
323  hdim_v,
324  num_head_q,
325  nhead_ratio_qk,
// pre-fold log2(e) into scale_s so the kernel can use exp2 instead of exp
326 #if CK_TILE_FMHA_FWD_FAST_EXP2
327  static_cast<float>(scale_s * ck_tile::log2e_v<>),
328 #else
329  scale_s,
330 #endif
331  stride_q,
332  stride_k,
333  stride_v,
334  stride_o,
335  nhead_stride_q,
336  nhead_stride_k,
337  nhead_stride_v,
338  nhead_stride_o}, // args for common karg
339  {}, // placeholder for bias
340  {}, // placeholder for mask
341  {}, // placeholder for lse
342  {}, // placeholder for fp8_static_quant args
343  {}, // placeholder for dropout
344  batch_stride_q,
345  batch_stride_k,
346  batch_stride_v,
347  batch_stride_o};
348 
// NOTE(review): the condition line (doxygen 349, presumably
// `if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)`) was lost
// in extraction -- the `else if` on line 356 implies a bias-kind branch here.
350  {
351  kargs.bias_ptr = bias_ptr;
352  kargs.stride_bias = stride_bias;
353  kargs.nhead_stride_bias = nhead_stride_bias;
354  kargs.batch_stride_bias = batch_stride_bias;
355  }
356  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
357  {
// for ALIBI the "bias" pointer/stride carry the per-head slope table
358  kargs.alibi_slope_ptr = bias_ptr;
359  kargs.alibi_slope_stride = stride_bias;
360  }
361  if constexpr(kHasMask)
362  {
363  kargs.window_size_left = window_size_left;
364  kargs.window_size_right = window_size_right;
365  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
366  }
367  if constexpr(kStoreLSE)
368  {
369  kargs.lse_ptr = lse_ptr;
370  kargs.nhead_stride_lse = nhead_stride_lse;
371  kargs.batch_stride_lse = batch_stride_lse;
372  }
373  if constexpr(kDoFp8StaticQuant)
374  {
375  kargs.scale_p = scale_p;
376  kargs.scale_o = scale_o;
377  }
378  if constexpr(kHasDropout)
379  {
// the variant's first alternative is host (seed, offset) values,
// the second is device pointers to be read inside the kernel
380  if(drop_seed_offset.index() == 0) // seed & offset come from host
381  {
382  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
383  kargs.init_dropout(p_drop, seed, offset);
384  }
385  else // seed & offset come from device
386  {
387  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
388  kargs.init_dropout(p_drop,
389  reinterpret_cast<const uint64_t*>(seed_ptr),
390  reinterpret_cast<const uint64_t*>(offset_ptr));
391  }
392 
393  kargs.rand_val_ptr = rand_val_ptr;
394  kargs.stride_randval = stride_randval;
395  kargs.nhead_stride_randval = nhead_stride_randval;
396  kargs.batch_stride_randval = batch_stride_randval;
397  kargs.is_store_randval = s_randval;
398  }
399 
400  return kargs;
401  }
402 
403  // std::variant<> can't take in a list initializer, overload for backward compatibility
404  template <bool Cond = !kIsGroupMode>
405  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
406  MakeKargs(const void* q_ptr,
407  const void* k_ptr,
408  const void* v_ptr,
409  const void* bias_ptr,
410  void* rand_val_ptr,
411  void* lse_ptr,
412  void* o_ptr,
413  ck_tile::index_t seqlen_q,
414  ck_tile::index_t seqlen_k,
415  ck_tile::index_t hdim_q,
416  ck_tile::index_t hdim_v,
417  ck_tile::index_t num_head_q,
418  ck_tile::index_t nhead_ratio_qk,
419  float scale_s,
420  float scale_p,
421  float scale_o,
422  ck_tile::index_t stride_q,
423  ck_tile::index_t stride_k,
424  ck_tile::index_t stride_v,
425  ck_tile::index_t stride_bias,
426  ck_tile::index_t stride_randval,
427  ck_tile::index_t stride_o,
428  ck_tile::index_t nhead_stride_q,
429  ck_tile::index_t nhead_stride_k,
430  ck_tile::index_t nhead_stride_v,
431  ck_tile::index_t nhead_stride_bias,
432  ck_tile::index_t nhead_stride_randval,
433  ck_tile::index_t nhead_stride_lse,
434  ck_tile::index_t nhead_stride_o,
435  ck_tile::index_t batch_stride_q,
436  ck_tile::index_t batch_stride_k,
437  ck_tile::index_t batch_stride_v,
438  ck_tile::index_t batch_stride_bias,
439  ck_tile::index_t batch_stride_randval,
440  ck_tile::index_t batch_stride_lse,
441  ck_tile::index_t batch_stride_o,
442  ck_tile::index_t window_size_left,
443  ck_tile::index_t window_size_right,
444  ck_tile::index_t mask_type,
445  float p_drop,
446  bool s_randval,
447  const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
448  {
449  return MakeKargsImpl(
450  q_ptr,
451  k_ptr,
452  v_ptr,
453  bias_ptr,
454  rand_val_ptr,
455  lse_ptr,
456  o_ptr,
457  seqlen_q,
458  seqlen_k,
459  hdim_q,
460  hdim_v,
461  num_head_q,
462  nhead_ratio_qk,
463  scale_s,
464  scale_p,
465  scale_o,
466  stride_q,
467  stride_k,
468  stride_v,
469  stride_bias,
470  stride_randval,
471  stride_o,
472  nhead_stride_q,
473  nhead_stride_k,
474  nhead_stride_v,
475  nhead_stride_bias,
476  nhead_stride_randval,
477  nhead_stride_lse,
478  nhead_stride_o,
479  batch_stride_q,
480  batch_stride_k,
481  batch_stride_v,
482  batch_stride_bias,
483  batch_stride_randval,
484  batch_stride_lse,
485  batch_stride_o,
486  window_size_left,
487  window_size_right,
488  mask_type,
489  p_drop,
490  s_randval,
491  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
492  }
493 
494  // std::variant<> can't take in a list initializer, overload for backward compatibility
495  template <bool Cond = !kIsGroupMode>
496  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
497  MakeKargs(const void* q_ptr,
498  const void* k_ptr,
499  const void* v_ptr,
500  const void* bias_ptr,
501  void* rand_val_ptr,
502  void* lse_ptr,
503  void* o_ptr,
504  ck_tile::index_t seqlen_q,
505  ck_tile::index_t seqlen_k,
506  ck_tile::index_t hdim_q,
507  ck_tile::index_t hdim_v,
508  ck_tile::index_t num_head_q,
509  ck_tile::index_t nhead_ratio_qk,
510  float scale_s,
511  float scale_p,
512  float scale_o,
513  ck_tile::index_t stride_q,
514  ck_tile::index_t stride_k,
515  ck_tile::index_t stride_v,
516  ck_tile::index_t stride_bias,
517  ck_tile::index_t stride_randval,
518  ck_tile::index_t stride_o,
519  ck_tile::index_t nhead_stride_q,
520  ck_tile::index_t nhead_stride_k,
521  ck_tile::index_t nhead_stride_v,
522  ck_tile::index_t nhead_stride_bias,
523  ck_tile::index_t nhead_stride_randval,
524  ck_tile::index_t nhead_stride_lse,
525  ck_tile::index_t nhead_stride_o,
526  ck_tile::index_t batch_stride_q,
527  ck_tile::index_t batch_stride_k,
528  ck_tile::index_t batch_stride_v,
529  ck_tile::index_t batch_stride_bias,
530  ck_tile::index_t batch_stride_randval,
531  ck_tile::index_t batch_stride_lse,
532  ck_tile::index_t batch_stride_o,
533  ck_tile::index_t window_size_left,
534  ck_tile::index_t window_size_right,
535  ck_tile::index_t mask_type,
536  float p_drop,
537  bool s_randval,
538  const std::tuple<const void*, const void*>& drop_seed_offset)
539  {
540  return MakeKargsImpl(
541  q_ptr,
542  k_ptr,
543  v_ptr,
544  bias_ptr,
545  rand_val_ptr,
546  lse_ptr,
547  o_ptr,
548  seqlen_q,
549  seqlen_k,
550  hdim_q,
551  hdim_v,
552  num_head_q,
553  nhead_ratio_qk,
554  scale_s,
555  scale_p,
556  scale_o,
557  stride_q,
558  stride_k,
559  stride_v,
560  stride_bias,
561  stride_randval,
562  stride_o,
563  nhead_stride_q,
564  nhead_stride_k,
565  nhead_stride_v,
566  nhead_stride_bias,
567  nhead_stride_randval,
568  nhead_stride_lse,
569  nhead_stride_o,
570  batch_stride_q,
571  batch_stride_k,
572  batch_stride_v,
573  batch_stride_bias,
574  batch_stride_randval,
575  batch_stride_lse,
576  batch_stride_o,
577  window_size_left,
578  window_size_right,
579  mask_type,
580  p_drop,
581  s_randval,
582  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
583  }
584 
585  template <bool Cond = kIsGroupMode>
586  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
587  MakeKargsImpl(const void* q_ptr,
588  const void* k_ptr,
589  const void* v_ptr,
590  const void* bias_ptr,
591  void* rand_val_ptr,
592  void* lse_ptr,
593  void* o_ptr,
594  const void* seqstart_q_ptr,
595  const void* seqstart_k_ptr,
596  const void* seqlen_k_ptr,
597  ck_tile::index_t hdim_q,
598  ck_tile::index_t hdim_v,
599  ck_tile::index_t num_head_q,
600  ck_tile::index_t nhead_ratio_qk,
601  float scale_s,
602  float scale_p,
603  float scale_o,
604  ck_tile::index_t stride_q,
605  ck_tile::index_t stride_k,
606  ck_tile::index_t stride_v,
607  ck_tile::index_t stride_bias,
608  ck_tile::index_t stride_randval,
609  ck_tile::index_t stride_o,
610  ck_tile::index_t nhead_stride_q,
611  ck_tile::index_t nhead_stride_k,
612  ck_tile::index_t nhead_stride_v,
613  ck_tile::index_t nhead_stride_bias,
614  ck_tile::index_t nhead_stride_randval,
615  ck_tile::index_t nhead_stride_lse,
616  ck_tile::index_t nhead_stride_o,
617  ck_tile::index_t window_size_left,
618  ck_tile::index_t window_size_right,
619  ck_tile::index_t mask_type,
620  float p_drop,
621  bool s_randval,
622  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
623  drop_seed_offset)
624  {
// Group-mode kernel-argument factory: no batch strides; per-sequence lengths are
// resolved in the kernel from the seqstart/seqlen pointers, so the seqlen fields
// are initialized to -1 here as sentinels.
625  Kargs kargs{{q_ptr,
626  k_ptr,
627  v_ptr,
628  o_ptr,
629  -1, // seqlen will be updated by another pointer
630  -1, //
631  hdim_q,
632  hdim_v,
633  num_head_q,
634  nhead_ratio_qk,
// pre-fold log2(e) into scale_s so the kernel can use exp2 instead of exp
635 #if CK_TILE_FMHA_FWD_FAST_EXP2
636  static_cast<float>(scale_s * ck_tile::log2e_v<>),
637 #else
638  scale_s,
639 #endif
640  stride_q,
641  stride_k,
642  stride_v,
643  stride_o,
644  nhead_stride_q,
645  nhead_stride_k,
646  nhead_stride_v,
647  nhead_stride_o}, // args for common karg
648  {}, // placeholder for bias
649  {}, // placeholder for mask
650  {}, // placeholder for lse
651  {}, // placeholder for fp8_static_quant args
652  {}, // placeholder for dropout
653  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
654  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
655  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
656 
// NOTE(review): the condition line (doxygen 657, presumably
// `if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)`) was lost
// in extraction -- the `else if` on line 663 implies a bias-kind branch here.
658  {
659  kargs.bias_ptr = bias_ptr;
660  kargs.stride_bias = stride_bias;
661  kargs.nhead_stride_bias = nhead_stride_bias;
662  }
663  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
664  {
// for ALIBI the "bias" pointer/stride carry the per-head slope table
665  kargs.alibi_slope_ptr = bias_ptr;
666  kargs.alibi_slope_stride = stride_bias;
667  }
668  if constexpr(kHasMask)
669  {
670  kargs.window_size_left = window_size_left;
671  kargs.window_size_right = window_size_right;
672  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
673  }
674  if constexpr(kStoreLSE)
675  {
676  kargs.lse_ptr = lse_ptr;
677  kargs.nhead_stride_lse = nhead_stride_lse;
678  }
679  if constexpr(kDoFp8StaticQuant)
680  {
681  kargs.scale_p = scale_p;
682  kargs.scale_o = scale_o;
683  }
684  if constexpr(kHasDropout)
685  {
// the variant's first alternative is host (seed, offset) values,
// the second is device pointers to be read inside the kernel
686  if(drop_seed_offset.index() == 0) // seed & offset come from host
687  {
688  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
689  kargs.init_dropout(p_drop, seed, offset);
690  }
691  else // seed & offset come from device
692  {
693  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
694  kargs.init_dropout(p_drop,
695  reinterpret_cast<const uint64_t*>(seed_ptr),
696  reinterpret_cast<const uint64_t*>(offset_ptr));
697  }
698 
699  kargs.rand_val_ptr = rand_val_ptr;
700  kargs.stride_randval = stride_randval;
701  kargs.nhead_stride_randval = nhead_stride_randval;
702  kargs.is_store_randval = s_randval;
703  }
704 
705  return kargs;
706  }
707 
708  // std::variant<> can't take in a list initializer, overload for backward compatibility
709  template <bool Cond = kIsGroupMode>
710  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
711  MakeKargs(const void* q_ptr,
712  const void* k_ptr,
713  const void* v_ptr,
714  const void* bias_ptr,
715  void* rand_val_ptr,
716  void* lse_ptr,
717  void* o_ptr,
718  const void* seqstart_q_ptr,
719  const void* seqstart_k_ptr,
720  const void* seqlen_k_ptr,
721  ck_tile::index_t hdim_q,
722  ck_tile::index_t hdim_v,
723  ck_tile::index_t num_head_q,
724  ck_tile::index_t nhead_ratio_qk,
725  float scale_s,
726  float scale_p,
727  float scale_o,
728  ck_tile::index_t stride_q,
729  ck_tile::index_t stride_k,
730  ck_tile::index_t stride_v,
731  ck_tile::index_t stride_bias,
732  ck_tile::index_t stride_randval,
733  ck_tile::index_t stride_o,
734  ck_tile::index_t nhead_stride_q,
735  ck_tile::index_t nhead_stride_k,
736  ck_tile::index_t nhead_stride_v,
737  ck_tile::index_t nhead_stride_bias,
738  ck_tile::index_t nhead_stride_randval,
739  ck_tile::index_t nhead_stride_lse,
740  ck_tile::index_t nhead_stride_o,
741  ck_tile::index_t window_size_left,
742  ck_tile::index_t window_size_right,
743  ck_tile::index_t mask_type,
744  float p_drop,
745  bool s_randval,
746  const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
747  {
748  return MakeKargsImpl(
749  q_ptr,
750  k_ptr,
751  v_ptr,
752  bias_ptr,
753  rand_val_ptr,
754  lse_ptr,
755  o_ptr,
756  seqstart_q_ptr,
757  seqstart_k_ptr,
758  seqlen_k_ptr,
759  hdim_q,
760  hdim_v,
761  num_head_q,
762  nhead_ratio_qk,
763  scale_s,
764  scale_p,
765  scale_o,
766  stride_q,
767  stride_k,
768  stride_v,
769  stride_bias,
770  stride_randval,
771  stride_o,
772  nhead_stride_q,
773  nhead_stride_k,
774  nhead_stride_v,
775  nhead_stride_bias,
776  nhead_stride_randval,
777  nhead_stride_lse,
778  nhead_stride_o,
779  window_size_left,
780  window_size_right,
781  mask_type,
782  p_drop,
783  s_randval,
784  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
785  }
786 
787  // std::variant<> can't take in a list initializer, overload for backward compatibility
788  template <bool Cond = kIsGroupMode>
789  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
790  MakeKargs(const void* q_ptr,
791  const void* k_ptr,
792  const void* v_ptr,
793  const void* bias_ptr,
794  void* rand_val_ptr,
795  void* lse_ptr,
796  void* o_ptr,
797  const void* seqstart_q_ptr,
798  const void* seqstart_k_ptr,
799  const void* seqlen_k_ptr,
800  ck_tile::index_t hdim_q,
801  ck_tile::index_t hdim_v,
802  ck_tile::index_t num_head_q,
803  ck_tile::index_t nhead_ratio_qk,
804  float scale_s,
805  float scale_p,
806  float scale_o,
807  ck_tile::index_t stride_q,
808  ck_tile::index_t stride_k,
809  ck_tile::index_t stride_v,
810  ck_tile::index_t stride_bias,
811  ck_tile::index_t stride_randval,
812  ck_tile::index_t stride_o,
813  ck_tile::index_t nhead_stride_q,
814  ck_tile::index_t nhead_stride_k,
815  ck_tile::index_t nhead_stride_v,
816  ck_tile::index_t nhead_stride_bias,
817  ck_tile::index_t nhead_stride_randval,
818  ck_tile::index_t nhead_stride_lse,
819  ck_tile::index_t nhead_stride_o,
820  ck_tile::index_t window_size_left,
821  ck_tile::index_t window_size_right,
822  ck_tile::index_t mask_type,
823  float p_drop,
824  bool s_randval,
825  const std::tuple<const void*, const void*>& drop_seed_offset)
826  {
827  return MakeKargsImpl(
828  q_ptr,
829  k_ptr,
830  v_ptr,
831  bias_ptr,
832  rand_val_ptr,
833  lse_ptr,
834  o_ptr,
835  seqstart_q_ptr,
836  seqstart_k_ptr,
837  seqlen_k_ptr,
838  hdim_q,
839  hdim_v,
840  num_head_q,
841  nhead_ratio_qk,
842  scale_s,
843  scale_p,
844  scale_o,
845  stride_q,
846  stride_k,
847  stride_v,
848  stride_bias,
849  stride_randval,
850  stride_o,
851  nhead_stride_q,
852  nhead_stride_k,
853  nhead_stride_v,
854  nhead_stride_bias,
855  nhead_stride_randval,
856  nhead_stride_lse,
857  nhead_stride_o,
858  window_size_left,
859  window_size_right,
860  mask_type,
861  p_drop,
862  s_randval,
863  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
864  }
865 
866  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
867  ck_tile::index_t nhead_,
868  ck_tile::index_t seqlen_q_,
869  ck_tile::index_t hdim_v_,
870  bool has_padded_seqlen_k = false)
871  {
872  // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
873  if(has_padded_seqlen_k)
874  {
875  // TODO: this may need tuning
876  return dim3(nhead_,
877  batch_size_,
878  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
879  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
880  }
881  else
882  {
883  // TODO: this may need tuning
884  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
885  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
886  nhead_,
887  batch_size_);
888  }
889  }
890 
891  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
892  {
893  bool has_padded_seqlen_k = false;
894 
895  if constexpr(kIsGroupMode)
896  has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
897 
898  if(has_padded_seqlen_k)
899  {
900  // const index_t num_tile_m0 = seqlen_q / kM0;
901  const index_t num_tile_n1 =
902  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
903 
904  const index_t i_block = blockIdx.z;
905  const index_t i_nhead = blockIdx.x;
906  const index_t i_batch = blockIdx.y;
907 
908  const auto f = [](index_t dividend, index_t divisor) {
909  index_t quotient = dividend / divisor;
910  index_t modulus = dividend - quotient * divisor;
911  return ck_tile::make_tuple(quotient, modulus);
912  };
913 
914  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
915 
916  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
917  }
918  else
919  {
920  // const index_t num_tile_m0 = seqlen_q / kM0;
921  const index_t num_tile_n1 =
922  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
923 
924  const index_t i_block = blockIdx.x;
925  const index_t i_nhead = blockIdx.y;
926  const index_t i_batch = blockIdx.z;
927 
928  const auto f = [](index_t dividend, index_t divisor) {
929  index_t quotient = dividend / divisor;
930  index_t modulus = dividend - quotient * divisor;
931  return ck_tile::make_tuple(quotient, modulus);
932  };
933 
934  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
935 
936  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
937  }
938  }
939 
940  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
941 
// NOTE(review): the signature line of the function below (doxygen line 942,
// presumably `CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()`)
// was lost in extraction -- confirm against upstream.
943  {
// one LDS buffer is shared by the pipeline and the epilogue (see operator()),
// so the requirement is the max of the two
944  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
945  }
946 
947  CK_TILE_DEVICE void operator()(Kargs kargs) const
948  {
949  // allocate LDS
950  __shared__ char smem_ptr[GetSmemSize()];
951 
952  // divide problem
953  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
954 
955  const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
956  const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
957 
958  long_index_t batch_offset_q = 0;
959  long_index_t batch_offset_k = 0;
960  long_index_t batch_offset_v = 0;
961  long_index_t batch_offset_bias = 0;
962  long_index_t batch_offset_randval = 0;
963  long_index_t batch_offset_lse = 0;
964  long_index_t batch_offset_o = 0;
965 
966  if constexpr(kIsGroupMode)
967  {
968  // get starting offset for each batch
969  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
970  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
971 
972  batch_offset_q = query_start * kargs.stride_q;
973  batch_offset_k = key_start * kargs.stride_k;
974  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
975  {
976  batch_offset_v = key_start * kargs.stride_v;
977  }
978  else
979  {
980  batch_offset_v = key_start;
981  }
983  {
984  batch_offset_bias = query_start * kargs.stride_bias + key_start;
985  }
986  if constexpr(kStoreLSE)
987  {
988  batch_offset_lse = query_start;
989  }
990  if constexpr(kHasDropout)
991  {
992  batch_offset_randval = query_start * kargs.stride_randval;
993  }
994  batch_offset_o = query_start * kargs.stride_o;
995 
996  // get real # queries & # keys under group mode
997  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
998  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
999 
1000  // # of required blocks is different in each groups, terminate unnecessary blocks
1001  // earlier
1002  if(kargs.seqlen_q <= i_m0)
1003  {
1004  return;
1005  }
1006 
1007  if(kargs.seqlen_k_ptr != nullptr)
1008  {
1009  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1010  }
1011  else
1012  {
1013  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1014  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1015  }
1016  }
1017  else
1018  {
1019  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1020  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1021  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1023  {
1024  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1025  }
1026  if constexpr(kStoreLSE)
1027  {
1028  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1029  }
1030  if constexpr(kHasDropout)
1031  {
1032  batch_offset_randval =
1033  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
1034  }
1035  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1036  }
1037 
1038  // for simplicity, batch stride we just modify the pointer
1039  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1040  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1041  batch_offset_q;
1042  const KDataType* k_ptr =
1043  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1044  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1045  batch_offset_k;
1046  const VDataType* v_ptr =
1047  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1048  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1049  batch_offset_v;
1050  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1051  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1052  batch_offset_o;
1053 
1054  // Q/K/V DRAM and DRAM window
1055  const auto q_dram = [&]() {
1056  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1057  q_ptr,
1058  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1059  make_tuple(kargs.stride_q, 1),
1061  number<1>{});
1062  if constexpr(FmhaPipeline::kQLoadOnce)
1063  {
1064  return pad_tensor_view(
1065  q_dram_naive,
1068  }
1069  else
1070  {
1071  return pad_tensor_view(
1072  q_dram_naive,
1075  }
1076  }();
1077  const auto k_dram = [&]() {
1078  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1079  k_ptr,
1080  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1081  make_tuple(kargs.stride_k, 1),
1083  number<1>{});
1084 
1085  return pad_tensor_view(
1086  k_dram_naive,
1089  }();
1090  const auto v_dram = [&]() {
1091  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1092  {
1093  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1094  v_ptr,
1095  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1096  make_tuple(kargs.stride_v, 1),
1098  number<1>{});
1099 
1100  const auto v_dram_transposed =
1101  transform_tensor_view(v_dram_naive,
1103  make_pass_through_transform(kargs.seqlen_k)),
1106 
1107  return pad_tensor_view(
1108  v_dram_transposed,
1111  }
1112  else
1113  {
1114  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1115  v_ptr,
1116  make_tuple(kargs.hdim_v, kargs.seqlen_k),
1117  make_tuple(kargs.stride_v, 1),
1119  number<1>{});
1120 
1121  return pad_tensor_view(
1122  v_dram_naive,
1125  }
1126  }();
1127 
1128  auto q_dram_window = make_tile_window(
1129  q_dram,
1130  [&]() {
1131  if constexpr(FmhaPipeline::kQLoadOnce)
1134  else
1136  }(),
1137  {i_m0, 0});
1138 
1139  auto k_dram_window = make_tile_window(
1141 
1142  auto v_dram_window =
1143  make_tile_window(v_dram,
1145  {i_n1, 0});
1148  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1149  constexpr auto bias_dram_window_lengths =
1152  {
1153  const BiasDataType* bias_ptr =
1154  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1155  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1156  batch_offset_bias;
1157 
1158  const auto bias_dram = [&]() {
1159  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1160  bias_ptr,
1161  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1162  make_tuple(kargs.stride_bias, 1),
1164  number<1>{});
1165 
1166  return pad_tensor_view(bias_dram_naive,
1167  bias_dram_window_lengths,
1169  }();
1170 
1171  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1172  }
1173  else
1174  {
1175  return make_null_tile_window(bias_dram_window_lengths);
1176  }
1177  }();
1178 
1179  // lse
1180  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1181  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1182  if constexpr(kStoreLSE)
1183  {
1184  LSEDataType* lse_ptr =
1185  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1186  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
1187 
1188  const auto lse_dram = [&]() {
1189  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1190  lse_ptr,
1191  make_tuple(kargs.seqlen_q),
1192  make_tuple(1),
1193  number<1>{},
1194  number<1>{});
1195 
1196  return pad_tensor_view(
1197  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1198  }();
1199 
1200  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1201  }
1202  else
1203  {
1204  return make_null_tile_window(lse_dram_window_lengths);
1205  }
1206  }();
1207 
1208  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1209  if constexpr(kHasDropout)
1210  {
1211  return BlockDropout{i_batch_,
1212  i_nhead_,
1213  kargs.num_head_q,
1214  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1215  : *kargs.drop_seed.ptr,
1216  kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
1217  : *kargs.drop_offset.ptr,
1218  kargs.rp_undrop,
1219  kargs.p_undrop_in_uint8_t,
1220  kargs.is_store_randval};
1221  }
1222  else
1223  {
1224  return NullBlockDropout{};
1225  };
1226  }();
1227 
1228  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1229  constexpr auto randval_dram_window_lengths =
1231  if constexpr(kHasDropout)
1232  {
1233  RandValOutputDataType* rand_val_ptr =
1234  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1235  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1236  batch_offset_randval;
1237 
1238  const auto randval_dram = [&]() {
1239  const auto randval_dram_naive =
1240  make_naive_tensor_view<address_space_enum::global>(
1241  rand_val_ptr,
1242  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1243  make_tuple(kargs.stride_randval, 1),
1244  number<1>{},
1245  number<1>{});
1246 
1247  return pad_tensor_view(randval_dram_naive,
1248  randval_dram_window_lengths,
1250  }();
1251 
1252  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1253  }
1254  else
1255  {
1256  return make_null_tile_window(randval_dram_window_lengths);
1257  }
1258  }();
1259 
1260  FmhaMask mask = [&]() {
1261  if constexpr(kHasMask)
1262  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1263  kargs.window_size_left,
1264  kargs.window_size_right,
1265  kargs.seqlen_q,
1266  kargs.seqlen_k,
1268  else
1269  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1270  }();
1271 
1272  // WA i_batch capture structure binding before c++20
1273  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1274  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1275  {
1276  // data loading, shared by entire wg
1277  // TODO: how to use s_read?
1278  SaccDataType slope =
1279  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1280  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1281 #if CK_TILE_FMHA_FWD_FAST_EXP2
1282  slope *= ck_tile::log2e_v<>;
1283 #endif
1284  if constexpr(kHasMask)
1285  {
1286  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1287  kargs.window_size_left,
1288  kargs.window_size_right,
1289  kargs.seqlen_q,
1290  kargs.seqlen_k,
1291  kargs.mask_type);
1292  }
1293  else
1294  {
1296  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1297  }
1298  }
1299  else
1300  {
1302  }
1303  }();
1304 
1305  auto o_acc_tile = [&]() {
1306  if constexpr(kDoFp8StaticQuant)
1307  {
1308  return FmhaPipeline{}(
1309  q_dram_window,
1310  identity{}, // q_element_func
1311  k_dram_window,
1312  identity{}, // k_element_func
1313  v_dram_window,
1314  identity{}, // v_element_func
1315  bias_dram_window,
1316  identity{}, // bias_element_func
1317  randval_dram_window,
1318  lse_dram_window,
1319  identity{}, // lse_element_func
1320  identity{}, // s_acc_element_func
1321  scales{kargs.scale_p}, // p_compute_element_func
1322  composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
1323  mask,
1324  position_encoding,
1325  kargs.scale_s,
1326  smem_ptr,
1327  dropout);
1328  }
1329  else
1330  {
1331  return FmhaPipeline{}(q_dram_window,
1332  k_dram_window,
1333  v_dram_window,
1334  bias_dram_window,
1335  randval_dram_window,
1336  lse_dram_window,
1337  mask,
1338  position_encoding,
1339  kargs.scale_s,
1340  smem_ptr,
1341  dropout);
1342  }
1343  }();
1344 
1345  // O DRAM and O DRAM window
1346  auto o_dram = [&]() {
1347  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1348  o_ptr,
1349  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1350  make_tuple(kargs.stride_o, 1),
1352  number<1>{});
1353 
1354  return pad_tensor_view(
1355  o_dram_naive,
1358  }();
1359 
1360  auto o_dram_window =
1361  make_tile_window(o_dram,
1363  {i_m0, i_n1});
1364 
1365  EpiloguePipeline{}(o_dram_window, o_acc_tile);
1366  }
1367 };
1368 
1369 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
#define _TS_
#define _SS_
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:63
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:461
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1615
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
int64_t long_index_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create macro to replace 'host device' and nothing more.
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_dropout.hpp:26
const float rp_undrop
Definition: block_dropout.hpp:290
Definition: block_position_encoding.hpp:137
Definition: fmha_fwd_kernel.hpp:153
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_kernel.hpp:156
const void * alibi_slope_ptr
Definition: fmha_fwd_kernel.hpp:155
Definition: fmha_fwd_kernel.hpp:148
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_kernel.hpp:149
Definition: fmha_fwd_kernel.hpp:229
ck_tile::index_t batch_stride_randval
Definition: fmha_fwd_kernel.hpp:230
Definition: fmha_fwd_kernel.hpp:244
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_kernel.hpp:248
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_kernel.hpp:245
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_kernel.hpp:246
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_kernel.hpp:247
Definition: fmha_fwd_kernel.hpp:141
const void * bias_ptr
Definition: fmha_fwd_kernel.hpp:142
ck_tile::index_t stride_bias
Definition: fmha_fwd_kernel.hpp:143
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_kernel.hpp:144
Definition: fmha_fwd_kernel.hpp:194
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition: fmha_fwd_kernel.hpp:207
float rp_undrop
Definition: fmha_fwd_kernel.hpp:219
ck_tile::index_t stride_randval
Definition: fmha_fwd_kernel.hpp:224
ck_tile::index_t nhead_stride_randval
Definition: fmha_fwd_kernel.hpp:225
void * rand_val_ptr
Definition: fmha_fwd_kernel.hpp:222
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition: fmha_fwd_kernel.hpp:195
bool is_store_randval
Definition: fmha_fwd_kernel.hpp:221
uint8_t p_undrop_in_uint8_t
Definition: fmha_fwd_kernel.hpp:220
Definition: fmha_fwd_kernel.hpp:112
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_kernel.hpp:135
float scale_s
Definition: fmha_fwd_kernel.hpp:127
ck_tile::index_t seqlen_k
Definition: fmha_fwd_kernel.hpp:119
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_kernel.hpp:137
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_kernel.hpp:126
ck_tile::index_t num_head_q
Definition: fmha_fwd_kernel.hpp:123
ck_tile::index_t hdim_q
Definition: fmha_fwd_kernel.hpp:120
const void * v_ptr
Definition: fmha_fwd_kernel.hpp:115
void * o_ptr
Definition: fmha_fwd_kernel.hpp:116
const void * k_ptr
Definition: fmha_fwd_kernel.hpp:114
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_kernel.hpp:134
ck_tile::index_t stride_k
Definition: fmha_fwd_kernel.hpp:130
ck_tile::index_t stride_o
Definition: fmha_fwd_kernel.hpp:132
ck_tile::index_t stride_v
Definition: fmha_fwd_kernel.hpp:131
ck_tile::index_t hdim_v
Definition: fmha_fwd_kernel.hpp:121
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_kernel.hpp:136
const void * q_ptr
Definition: fmha_fwd_kernel.hpp:113
ck_tile::index_t seqlen_q
Definition: fmha_fwd_kernel.hpp:118
ck_tile::index_t stride_q
Definition: fmha_fwd_kernel.hpp:129
Definition: fmha_fwd_kernel.hpp:173
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_kernel.hpp:176
void * lse_ptr
Definition: fmha_fwd_kernel.hpp:174
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_kernel.hpp:175
Definition: fmha_fwd_kernel.hpp:180
bool is_drop_seed_offset_from_host
Definition: fmha_fwd_kernel.hpp:190
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_fwd_kernel.hpp:188
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_fwd_kernel.hpp:189
Definition: fmha_fwd_kernel.hpp:105
Definition: fmha_fwd_kernel.hpp:167
float scale_o
Definition: fmha_fwd_kernel.hpp:169
float scale_p
Definition: fmha_fwd_kernel.hpp:168
Definition: fmha_fwd_kernel.hpp:262
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_kernel.hpp:263
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_kernel.hpp:265
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_kernel.hpp:264
Definition: fmha_fwd_kernel.hpp:160
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_kernel.hpp:163
ck_tile::index_t window_size_right
Definition: fmha_fwd_kernel.hpp:162
ck_tile::index_t window_size_left
Definition: fmha_fwd_kernel.hpp:162
Definition: fmha_fwd_kernel.hpp:58
Definition: fmha_fwd_kernel.hpp:25
static constexpr bool kHasDropout
Definition: fmha_fwd_kernel.hpp:52
static CK_TILE_HOST std::string GetName()
Definition: fmha_fwd_kernel.hpp:66
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:711
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:272
static constexpr bool kDoFp8StaticQuant
Definition: fmha_fwd_kernel.hpp:53
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:790
static constexpr bool kStoreLSE
Definition: fmha_fwd_kernel.hpp:51
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:406
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_kernel.hpp:34
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_kernel.hpp:268
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:497
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_kernel.hpp:29
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_kernel.hpp:40
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_kernel.hpp:43
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:587
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_kernel.hpp:28
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_kernel.hpp:36
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k=false)
Definition: fmha_fwd_kernel.hpp:866
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_fwd_kernel.hpp:940
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_kernel.hpp:35
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_kernel.hpp:31
static constexpr auto BiasEnum
Definition: fmha_fwd_kernel.hpp:50
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_kernel.hpp:49
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_kernel.hpp:891
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_kernel.hpp:39
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_kernel.hpp:33
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_kernel.hpp:942
static constexpr bool kHasMask
Definition: fmha_fwd_kernel.hpp:55
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_kernel.hpp:26
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_kernel.hpp:41
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_kernel.hpp:46
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_fwd_kernel.hpp:38
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_kernel.hpp:27
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_kernel.hpp:47
static constexpr bool kIsGroupMode
Definition: fmha_fwd_kernel.hpp:45
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_kernel.hpp:947
Definition: block_dropout.hpp:12
Definition: integral_constant.hpp:13
Definition: functional.hpp:62
Definition: coordinate_transform.hpp:1443
Definition: unary_element_function.hpp:56
Definition: math.hpp:28
Definition: sequence.hpp:52
const T * ptr
Definition: fmha_fwd_kernel.hpp:185