fmha_bwd_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20 // dV[seqlen_k, hdim_v] = P^T[seqlen_k, seqlen_q] @ dO^T[hdim_v, seqlen_q]
21 // dP[seqlen_q, seqlen_k] = dO[seqlen_q, hdim_v] @ V[seqlen_k, hdim_v]
22 // D[seqlen_q] = rowsum(dO[seqlen_q, hdim_v] * O[seqlen_q, hdim_v])
23 // dS''[seqlen_q, seqlen_k] = P[seqlen_q, seqlen_k] * (dP[seqlen_q, seqlen_k] - D[seqlen_q])
24 // dBias[seqlen_q, seqlen_k] = dS'[seqlen_q, seqlen_k] = dS''[seqlen_q, seqlen_k]
25 // dK[seqlen_k, hdim_q] = dS'^T[seqlen_k, seqlen_q] @ Q^T[hdim_q, seqlen_q] * Scale[1]
26 // dQ[seqlen_q, hdim_q] = dS'[seqlen_q, seqlen_k] @ K^T[hdim_q, seqlen_k] * Scale[1]
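// For one query row i, the dS'' update above is the standard softmax
// backward: dS''_ij = P_ij * (dP_ij - D_i), where D_i = rowsum(dO_i * O_i) is
// the Jacobian correction term. A scalar reference sketch of that step
// (plain C++ for illustration; the loop names and float types are
// assumptions, not the tiled types used below):
//
//   float D = 0;
//   for(int d = 0; d < hdim_v; ++d)
//       D += dO[d] * O[d];              // D[i] = rowsum(dO * O)
//   for(int j = 0; j < seqlen_k; ++j)
//       dS[j] = P[j] * (dP[j] - D);     // dS''[i, j]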
27 
28 namespace ck_tile {
29 
30 template <typename FmhaPipeline_,
31  typename KGradEpiloguePipeline_,
32  typename VGradEpiloguePipeline_,
33  typename QGradEpiloguePipeline_ = void>
34 struct FmhaBwdDQDKDVKernel
35 {
36  using FmhaPipeline = remove_cvref_t<FmhaPipeline_>;
37  using KGradEpiloguePipeline = remove_cvref_t<KGradEpiloguePipeline_>;
38  using VGradEpiloguePipeline = remove_cvref_t<VGradEpiloguePipeline_>;
39  using QGradEpiloguePipeline = remove_cvref_t<QGradEpiloguePipeline_>;
40  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
41  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
42  static constexpr bool kUseQrQtrDorPipeline =
44  static_assert(!kUseQrQtrDorPipeline || !std::is_same_v<QGradEpiloguePipeline_, void>,
45  "QrQtrDorPipeline needs QGradEpiloguePipeline");
46 
47  using QDataType = remove_cvref_t<typename FmhaPipeline::QDataType>;
48  using KDataType = remove_cvref_t<typename FmhaPipeline::KDataType>;
49  using VDataType = remove_cvref_t<typename FmhaPipeline::VDataType>;
50  using BiasDataType = remove_cvref_t<typename FmhaPipeline::BiasDataType>;
51  using LSEDataType = remove_cvref_t<typename FmhaPipeline::LSEDataType>;
52  using AccDataType = remove_cvref_t<typename FmhaPipeline::AccDataType>;
53  using DDataType = remove_cvref_t<typename FmhaPipeline::DDataType>;
54  using RandValOutputDataType = remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
55  using OGradDataType = remove_cvref_t<typename FmhaPipeline::OGradDataType>;
56  using QGradDataType = remove_cvref_t<typename FmhaPipeline::QGradDataType>;
57  using KGradDataType = remove_cvref_t<typename FmhaPipeline::KGradDataType>;
58  using VGradDataType = remove_cvref_t<typename FmhaPipeline::VGradDataType>;
59  using BiasGradDataType = remove_cvref_t<typename FmhaPipeline::BiasGradDataType>;
60  using FmhaMask = remove_cvref_t<typename FmhaPipeline::FmhaMask>;
61  using FmhaDropout = remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
62 
63  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
64  static constexpr index_t kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
65  static constexpr index_t kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
66  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
67  static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
70  static constexpr bool kHasMask = FmhaMask::IsMasking;
71  static constexpr bool kHasDropout = FmhaDropout::IsDropout;
72  static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
73  static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
74  static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
75  static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ;
76  static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0));
77 #if defined(__gfx950__)
78  static constexpr bool kIsAvailable = true;
79 #else
80  static constexpr bool kIsAvailable = !kUseTrLoad;
81 #endif
82 
83  // clang-format off
84  template <typename T> struct t2s;
85  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
86  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
87  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
88  // clang-format on
89 
90  CK_TILE_HOST static std::string GetName()
91  {
92  // sync with generate.py
93  // clang-format off
94  using bfs = typename FmhaPipeline::BlockFmhaShape;
95  using gbr0 = typename bfs::Gemm0BlockWarps;
96  using gbr1 = typename bfs::Gemm1BlockWarps;
97  using gbr4 = typename bfs::Gemm4BlockWarps;
98  using gwt0 = typename bfs::Gemm0WarpTile;
99  using gwt1 = typename bfs::Gemm1WarpTile;
100  #define _SS_ std::string
101  #define _TS_ std::to_string
102  auto pn = [&] () {
103  std::string n;
104  if (kPadHeadDimQ) n += "d" + _TS_(kPadHeadDimQ);
105  if (kPadHeadDimV) n += "dv"+ _TS_(kPadHeadDimV);
106  return n.empty() ? n : std::string("p") + n; }();
107  return
108  _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
109  "_" + (kIsGroupMode ? "group" : "batch") + "_" +
110  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" +
111  _TS_(bfs::kK4) + "x" + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
112  "r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" +
113  "r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" +
114  "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
115  "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
116  "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
117  ("o" + _TS_(kBlockPerCu)) + "_" +
118  ("maxq" + _TS_(kMaxSeqLenQ)) +
119  (pn.empty() ? "_npad" : "_" + pn) +
121  (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? gwt0::at(ck_tile::number<0>{}) == 16? "_dropout_wg16":"_dropout_wg32" : "_ndropout" ) +
122  (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload");
123  #undef _SS_
124  #undef _TS_
125  // clang-format on
126  }
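// For illustration only, names produced by the scheme above have the shape
//   fmha_bwd_d<hdim>_<dtype>_<batch|group>_b<tile dims>_r<warps>_w<warp tiles>
//   _o<occupancy>_maxq<n>_<pad>_<dbias>_<mask>_<dropout>_<randval>_<det>_<trload>
// so a fp16 batch-mode kernel could start with "fmha_bwd_d128_fp16_batch_b...";
// the exact fields depend on the instantiated tile shapes.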
127 
128  template <ck_tile::index_t I> // to avoid the duplicated base class problem,
129  // introduce a template arg
130  struct FmhaBwdEmptyKargs
131  {
132  };
133 
134  // kargs use aggregate initialization, so no constructor is provided;
135  // inheritance is used to minimize the karg size.
136  // Users need to use the MakeKargs() function to create kargs.
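// The conditional-inheritance pattern used by the kargs below, in miniature
// (a standalone sketch with made-up names, relying on the empty-base
// optimization so disabled features add no bytes):
//
//   struct Common { int always_present; };
//   struct MaskPart { int window_left, window_right; };
//   template <int> struct Empty {};
//
//   template <bool kHasMask>
//   struct MiniKargs : Common, std::conditional_t<kHasMask, MaskPart, Empty<0>>
//   {
//   };
//
//   // sizeof(MiniKargs<false>) == sizeof(Common) on typical ABIs, and
//   // aggregate init fills the bases in order: MiniKargs<true>{{1}, {2, 3}}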
137  struct FmhaBwdCommonKargs
138  {
139  const void* q_ptr;
140  const void* k_ptr;
141  const void* v_ptr;
142  const void* lse_ptr;
143  const void* do_ptr;
144  const void* d_ptr;
145  void* dq_acc_ptr; // can be dq_ptr for qrqtrdor pipeline
146  void* dk_ptr;
147  void* dv_ptr;
148 
149  ck_tile::index_t seqlen_q;
150  ck_tile::index_t seqlen_k;
151  ck_tile::index_t hdim_q;
152  ck_tile::index_t hdim_v;
153 
154  // for MQA/GQA, nhead could differ between Q and K. This parameter is
155  // nhead_q / nhead_k; a value larger than 1 (e.g. 8 / 2 = 4) indicates MQA/GQA
156  ck_tile::index_t num_head_q;
157  ck_tile::index_t nhead_ratio_qk;
158  float raw_scale;
159  float scale;
160 
161  ck_tile::index_t stride_q;
162  ck_tile::index_t stride_k;
163  ck_tile::index_t stride_v;
164  ck_tile::index_t stride_do;
165  ck_tile::index_t stride_dq_acc;
166  ck_tile::index_t stride_dk;
167  ck_tile::index_t stride_dv;
168 
169  ck_tile::index_t nhead_stride_q;
170  ck_tile::index_t nhead_stride_k;
171  ck_tile::index_t nhead_stride_v;
172  ck_tile::index_t nhead_stride_do;
173  ck_tile::index_t nhead_stride_lsed;
174  ck_tile::index_t nhead_stride_dq_acc;
175  ck_tile::index_t nhead_stride_dk;
176  ck_tile::index_t nhead_stride_dv;
177  };
178 
179  struct FmhaBwdCommonBiasKargs
180  {
181  const void* bias_ptr = nullptr;
182  ck_tile::index_t stride_bias = 0;
183  ck_tile::index_t nhead_stride_bias = 0;
184  };
185 
186  struct FmhaBwdBatchModeBiasKargs : FmhaBwdCommonBiasKargs
187  {
188  ck_tile::index_t batch_stride_bias = 0;
189  };
190 
191  struct FmhaBwdAlibiKargs
192  {
193  // the alibi slope tensor is batch*nhead*1; batch and group mode share the same layout
194  const void* alibi_slope_ptr;
195  ck_tile::index_t alibi_slope_stride; // stride in the batch dim, or 0 if all batches share the same slopes
196  };
197 
198  struct FmhaBwdCommonBiasGradKargs
199  {
200  void* dbias_ptr = nullptr;
201  ck_tile::index_t stride_dbias = 0;
202  ck_tile::index_t nhead_stride_dbias = 0;
203  };
204 
205  struct FmhaBwdBatchModeBiasGradKargs : FmhaBwdCommonBiasGradKargs
206  {
207  ck_tile::index_t batch_stride_dbias = 0;
208  };
209 
210  struct FmhaBwdMaskKargs
211  {
212  ck_tile::index_t window_size_left, window_size_right;
213  ck_tile::GenericAttentionMaskEnum mask_type;
214  };
215 
216  struct FmhaBwdDropoutSeedOffset
217  {
218  template <typename T>
219  union ValueOrPointer
220  {
221  T val;
222  const T* ptr;
223  };
224 
225  ValueOrPointer<uint64_t> drop_seed;
226  ValueOrPointer<uint64_t> drop_offset;
227  bool is_drop_seed_offset_from_host;
228  };
229 
230  struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
231  {
232  void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
233  {
234  float p_undrop = 1.0 - p_drop;
235  p_undrop_in_uint8_t =
236  uint8_t(floor(p_undrop * std::numeric_limits<uint8_t>::max()));
237  rp_undrop = 1.0 / p_undrop;
238  scale_rp_undrop = rp_undrop * raw_scale;
239 
240  this->drop_seed.val = seed;
241  this->drop_offset.val = offset;
242  this->is_drop_seed_offset_from_host = true;
243  }
244 
245  void init_dropout(float p_drop,
246  const uint64_t* seed_ptr,
247  const uint64_t* offset_ptr,
248  float raw_scale)
249  {
250  float p_undrop = 1.0 - p_drop;
251  p_undrop_in_uint8_t =
252  uint8_t(floor(p_undrop * std::numeric_limits<uint8_t>::max()));
253  rp_undrop = 1.0 / p_undrop;
254  scale_rp_undrop = rp_undrop * raw_scale;
255 
256  this->drop_seed.ptr = seed_ptr;
257  this->drop_offset.ptr = offset_ptr;
258  this->is_drop_seed_offset_from_host = false;
259  }
260 
261  float rp_undrop = 1;
262  float scale_rp_undrop = 1;
263  uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
264  void* rand_val_ptr = nullptr;
265 
266  ck_tile::index_t stride_randval = 0;
267  ck_tile::index_t nhead_stride_randval = 0;
268  };
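// Worked example of the precomputation in init_dropout, assuming
// p_drop = 0.1 and raw_scale = 1.0:
//   p_undrop        = 1 - 0.1 = 0.9
//   rp_undrop       = 1 / 0.9 ≈ 1.1111            // kept elements are scaled up
//   scale_rp_undrop = rp_undrop * raw_scale ≈ 1.1111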
269 
270  struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
271  {
272  ck_tile::index_t batch_stride_randval = 0;
273  };
274 
275  struct FmhaBwdDeterministicKargs
276  {
277  ck_tile::index_t split_stride_dq_acc = 0;
278  };
279 
280  struct FmhaBwdBatchModeKargs
281  : FmhaBwdCommonKargs,
282  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
283  FmhaBwdBatchModeBiasKargs,
284  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
285  FmhaBwdAlibiKargs,
286  FmhaBwdEmptyKargs<0>>>,
287  std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
288  std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
289  std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>,
290  std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
291  {
292  ck_tile::index_t batch_stride_q;
293  ck_tile::index_t batch_stride_k;
294  ck_tile::index_t batch_stride_v;
295  ck_tile::index_t batch_stride_do;
296  ck_tile::index_t batch_stride_lsed;
297  ck_tile::index_t batch_stride_dq_acc;
298  ck_tile::index_t batch_stride_dk;
299  ck_tile::index_t batch_stride_dv;
300  };
301 
302  struct FmhaBwdGroupModeKargs
303  : FmhaBwdCommonKargs,
304  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
305  FmhaBwdCommonBiasKargs,
306  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
307  FmhaBwdAlibiKargs,
308  FmhaBwdEmptyKargs<0>>>,
309  std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
310  std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
311  std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
312  std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
313  {
314  const int32_t* seqstart_q_ptr; // physical sequence start offsets [batch+1]
315  const int32_t* seqstart_k_ptr; // physical sequence start offsets [batch+1]
316  const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
317  const int32_t* seqlen_k_ptr; // per-batch actual length [batch]
318  const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
319  const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional
320  };
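// Example of the group-mode length arrays, assuming 3 sequences of logical
// lengths 5, 3 and 7 packed back-to-back with no extra padding:
//   seqstart_q_ptr  = {0, 5, 8, 15}   // physical start offsets [batch+1]
//   seqlen_q_ptr    = {5, 3, 7}       // optional per-batch lengths [batch]
//   cu_seqlen_q_ptr = {0, 5, 8, 15}   // optional cumulative lengths [batch+1]
// The pointer-based fields may disagree with the physical layout (e.g. when
// sequences are over-allocated); operator() resolves them with the priority
// cu_seqlen_*_ptr > seqlen_*_ptr > seqstart difference.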
321 
322  using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
323 
324  // std::variant<> can't be initialized from a braced list; this overload keeps backward compatibility
325  template <typename... Ts>
326  CK_TILE_HOST static constexpr Kargs
327  MakeKargs(Ts... args, const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
328  {
329  return MakeKargsImpl(
330  args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
331  }
332 
333  // std::variant<> can't be initialized from a braced list; this overload keeps backward compatibility
334  template <typename... Ts>
335  CK_TILE_HOST static constexpr Kargs
336  MakeKargs(Ts... args, const std::tuple<const void*, const void*>& drop_seed_offset)
337  {
338  return MakeKargsImpl(
339  args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
340  }
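// Both overloads forward the seed/offset pair into the std::variant parameter
// of MakeKargsImpl below. A minimal sketch of how the two alternatives are
// told apart (illustrative values only):
//
//   std::variant<std::pair<uint64_t, uint64_t>,
//                std::pair<const void*, const void*>> v;
//   v = std::make_pair(uint64_t{1234}, uint64_t{0}); // host-provided values
//   if(v.index() == 0) { /* std::get<0>(v): seed/offset values */ }
//   else               { /* std::get<1>(v): device pointers     */ }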
341 
342  template <bool Cond = !kIsGroupMode>
343  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
344  MakeKargsImpl(const void* q_ptr,
345  const void* k_ptr,
346  const void* v_ptr,
347  const void* bias_ptr,
348  const void* lse_ptr,
349  const void* do_ptr,
350  const void* d_ptr,
351  void* rand_val_ptr,
352  void* dk_ptr,
353  void* dv_ptr,
354  void* dbias_ptr,
355  void* dq_acc_ptr, // can be dq_ptr for the qrqtrdor pipeline
356  ck_tile::index_t seqlen_q,
357  ck_tile::index_t seqlen_k,
358  ck_tile::index_t hdim_q,
359  ck_tile::index_t hdim_v,
360  ck_tile::index_t num_head_q,
361  ck_tile::index_t nhead_ratio_qk,
362  float scale,
363  ck_tile::index_t stride_q,
364  ck_tile::index_t stride_k,
365  ck_tile::index_t stride_v,
366  ck_tile::index_t stride_bias,
367  ck_tile::index_t stride_randval,
368  ck_tile::index_t stride_do,
369  ck_tile::index_t stride_dq_acc,
370  ck_tile::index_t stride_dk,
371  ck_tile::index_t stride_dv,
372  ck_tile::index_t stride_dbias,
373  ck_tile::index_t nhead_stride_q,
374  ck_tile::index_t nhead_stride_k,
375  ck_tile::index_t nhead_stride_v,
376  ck_tile::index_t nhead_stride_bias,
377  ck_tile::index_t nhead_stride_randval,
378  ck_tile::index_t nhead_stride_do,
379  ck_tile::index_t nhead_stride_lsed,
380  ck_tile::index_t nhead_stride_dq_acc,
381  ck_tile::index_t nhead_stride_dk,
382  ck_tile::index_t nhead_stride_dv,
383  ck_tile::index_t nhead_stride_dbias,
384  ck_tile::index_t batch_stride_q,
385  ck_tile::index_t batch_stride_k,
386  ck_tile::index_t batch_stride_v,
387  ck_tile::index_t batch_stride_bias,
388  ck_tile::index_t batch_stride_randval,
389  ck_tile::index_t batch_stride_do,
390  ck_tile::index_t batch_stride_lsed,
391  ck_tile::index_t batch_stride_dq_acc,
392  ck_tile::index_t batch_stride_dk,
393  ck_tile::index_t batch_stride_dv,
394  ck_tile::index_t batch_stride_dbias,
395  ck_tile::index_t split_stride_dq_acc,
396  ck_tile::index_t window_size_left,
397  ck_tile::index_t window_size_right,
398  ck_tile::index_t mask_type,
399  float p_drop,
400  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
401  drop_seed_offset)
402  {
403  Kargs kargs{{q_ptr,
404  k_ptr,
405  v_ptr,
406  lse_ptr,
407  do_ptr,
408  d_ptr,
409  dq_acc_ptr,
410  dk_ptr,
411  dv_ptr,
412  seqlen_q,
413  seqlen_k,
414  hdim_q,
415  hdim_v,
416  num_head_q,
417  nhead_ratio_qk,
418  scale,
419  static_cast<float>(scale * ck_tile::log2e_v<>),
420  stride_q,
421  stride_k,
422  stride_v,
423  stride_do,
424  stride_dq_acc,
425  stride_dk,
426  stride_dv,
427  nhead_stride_q,
428  nhead_stride_k,
429  nhead_stride_v,
430  nhead_stride_do,
431  nhead_stride_lsed,
432  nhead_stride_dq_acc,
433  nhead_stride_dk,
434  nhead_stride_dv}, // args for common karg
435  {}, // placeholder for bias
436  {}, // placeholder for dbias
437  {}, // placeholder for mask
438  {}, // placeholder for dropout
439  {}, // placeholder for deterministic
440  batch_stride_q,
441  batch_stride_k,
442  batch_stride_v,
443  batch_stride_do,
444  batch_stride_lsed,
445  batch_stride_dq_acc,
446  batch_stride_dk,
447  batch_stride_dv};
448 
449  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
450  {
451  kargs.bias_ptr = bias_ptr;
452  kargs.stride_bias = stride_bias;
453  kargs.nhead_stride_bias = nhead_stride_bias;
454  kargs.batch_stride_bias = batch_stride_bias;
455  }
456  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
457  {
458  kargs.alibi_slope_ptr = bias_ptr;
459  kargs.alibi_slope_stride = stride_bias;
460  }
461 
462  if constexpr(kHasBiasGrad)
463  {
464  kargs.dbias_ptr = dbias_ptr;
465  kargs.stride_dbias = stride_dbias;
466  kargs.nhead_stride_dbias = nhead_stride_dbias;
467  kargs.batch_stride_dbias = batch_stride_dbias;
468  }
469 
470  if constexpr(kHasMask)
471  {
472  kargs.window_size_left = window_size_left;
473  kargs.window_size_right = window_size_right;
474  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
475  }
476 
477  if constexpr(kHasDropout)
478  {
479  if(drop_seed_offset.index() == 0) // seed & offset come from host
480  {
481  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
482  kargs.init_dropout(p_drop, seed, offset, scale);
483  }
484  else // seed & offset come from device
485  {
486  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
487  kargs.init_dropout(p_drop,
488  reinterpret_cast<const uint64_t*>(seed_ptr),
489  reinterpret_cast<const uint64_t*>(offset_ptr),
490  scale);
491  }
492 
493  if constexpr(kIsStoreRandval)
494  {
495  kargs.rand_val_ptr = rand_val_ptr;
496  kargs.stride_randval = stride_randval;
497  kargs.nhead_stride_randval = nhead_stride_randval;
498  kargs.batch_stride_randval = batch_stride_randval;
499  }
500  }
501 
502  if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline)
503  {
504  kargs.split_stride_dq_acc = split_stride_dq_acc;
505  }
506 
507  return kargs;
508  }
509 
510  template <bool Cond = kIsGroupMode>
511  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
512  MakeKargsImpl(const void* q_ptr,
513  const void* k_ptr,
514  const void* v_ptr,
515  const void* bias_ptr,
516  const void* lse_ptr,
517  const void* do_ptr,
518  const void* d_ptr,
519  void* rand_val_ptr,
520  void* dk_ptr,
521  void* dv_ptr,
522  void* dbias_ptr,
523  void* dq_acc_ptr,
524  const void* seqstart_q_ptr,
525  const void* seqstart_k_ptr,
526  const void* seqlen_q_ptr,
527  const void* seqlen_k_ptr,
528  const void* cu_seqlen_q_ptr,
529  const void* cu_seqlen_k_ptr,
530  ck_tile::index_t hdim_q,
531  ck_tile::index_t hdim_v,
532  ck_tile::index_t num_head_q,
533  ck_tile::index_t nhead_ratio_qk,
534  float scale,
535  ck_tile::index_t stride_q,
536  ck_tile::index_t stride_k,
537  ck_tile::index_t stride_v,
538  ck_tile::index_t stride_bias,
539  ck_tile::index_t stride_randval,
540  ck_tile::index_t stride_do,
541  ck_tile::index_t stride_dq_acc,
542  ck_tile::index_t stride_dk,
543  ck_tile::index_t stride_dv,
544  ck_tile::index_t stride_dbias,
545  ck_tile::index_t nhead_stride_q,
546  ck_tile::index_t nhead_stride_k,
547  ck_tile::index_t nhead_stride_v,
548  ck_tile::index_t nhead_stride_bias,
549  ck_tile::index_t nhead_stride_randval,
550  ck_tile::index_t nhead_stride_do,
551  ck_tile::index_t nhead_stride_lsed,
552  ck_tile::index_t nhead_stride_dq_acc,
553  ck_tile::index_t nhead_stride_dk,
554  ck_tile::index_t nhead_stride_dv,
555  ck_tile::index_t nhead_stride_dbias,
556  ck_tile::index_t split_stride_dq_acc,
557  ck_tile::index_t window_size_left,
558  ck_tile::index_t window_size_right,
559  ck_tile::index_t mask_type,
560  float p_drop,
561  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
562  drop_seed_offset)
563  {
564  Kargs kargs{{q_ptr,
565  k_ptr,
566  v_ptr,
567  lse_ptr,
568  do_ptr,
569  d_ptr,
570  dq_acc_ptr,
571  dk_ptr,
572  dv_ptr,
573  -1, // seqlen will be updated by another pointer
574  -1, //
575  hdim_q,
576  hdim_v,
577  num_head_q,
578  nhead_ratio_qk,
579  scale,
580  static_cast<float>(scale * ck_tile::log2e_v<>),
581  stride_q,
582  stride_k,
583  stride_v,
584  stride_do,
585  stride_dq_acc,
586  stride_dk,
587  stride_dv,
588  nhead_stride_q,
589  nhead_stride_k,
590  nhead_stride_v,
591  nhead_stride_do,
592  nhead_stride_lsed,
593  nhead_stride_dq_acc,
594  nhead_stride_dk,
595  nhead_stride_dv}, // args for common karg
596  {}, // placeholder for bias
597  {}, // placeholder for dbias
598  {}, // placeholder for mask
599  {}, // placeholder for dropout
600  {}, // placeholder for deterministic
601  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
602  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
603  reinterpret_cast<const int32_t*>(seqlen_q_ptr),
604  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
605  reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
606  reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
607 
608  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
609  {
610  kargs.bias_ptr = bias_ptr;
611  kargs.stride_bias = stride_bias;
612  kargs.nhead_stride_bias = nhead_stride_bias;
613  }
614  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
615  {
616  kargs.alibi_slope_ptr = bias_ptr;
617  kargs.alibi_slope_stride = stride_bias;
618  }
619  if constexpr(kHasBiasGrad)
620  {
621  kargs.dbias_ptr = dbias_ptr;
622  kargs.stride_dbias = stride_dbias;
623  kargs.nhead_stride_dbias = nhead_stride_dbias;
624  }
625  if constexpr(kHasMask)
626  {
627  kargs.window_size_left = window_size_left;
628  kargs.window_size_right = window_size_right;
629  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
630  }
631  if constexpr(kHasDropout)
632  {
633  if(drop_seed_offset.index() == 0) // seed & offset come from host
634  {
635  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
636  kargs.init_dropout(p_drop, seed, offset, scale);
637  }
638  else // seed & offset come from device
639  {
640  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
641  kargs.init_dropout(p_drop,
642  reinterpret_cast<const uint64_t*>(seed_ptr),
643  reinterpret_cast<const uint64_t*>(offset_ptr),
644  scale);
645  }
646 
647  if constexpr(kIsStoreRandval)
648  {
649  kargs.rand_val_ptr = rand_val_ptr;
650  kargs.stride_randval = stride_randval;
651  kargs.nhead_stride_randval = nhead_stride_randval;
652  }
653  }
654  if constexpr(kIsDeterministic)
655  {
656  kargs.split_stride_dq_acc = split_stride_dq_acc;
657  }
658 
659  return kargs;
660  }
661 
662  CK_TILE_HOST static constexpr auto
663  GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
664  {
665  return dim3(
666  kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0),
667  nhead_,
668  batch_size_);
669  }
670 
671  CK_TILE_DEVICE static constexpr auto GetTileIndex()
672  {
673  const index_t i_block = blockIdx.x;
674  const index_t i_nhead = blockIdx.y;
675  const index_t i_batch = blockIdx.z;
676 
677  return ck_tile::make_tuple(i_block, i_nhead, i_batch);
678  }
679 
680  CK_TILE_HOST static dim3 BlockSize()
681  {
682  if(is_wave32())
683  {
684  return dim3(kBlockSize / 2);
685  }
686  else
687  {
688  return dim3(kBlockSize);
689  }
690  }
691 
692  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
693  {
694  return ck_tile::max(FmhaPipeline::GetSmemSize(),
695  KGradEpiloguePipeline::GetSmemSize(),
696  VGradEpiloguePipeline::GetSmemSize());
697  }
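// Hedged host-side launch sketch: the grid is (k-tiles, nheads, batches) in
// batch mode, and BlockSize() already accounts for wave32 targets. The kernel
// entry wrapper below is an assumption, not an API defined in this file:
//
//   using Kernel = FmhaBwdDQDKDVKernel</* pipelines */>;
//   dim3 grids  = Kernel::GridSize(batch, nhead, seqlen_k);
//   dim3 blocks = Kernel::BlockSize();
//   // kernel_entry<Kernel><<<grids, blocks, 0, stream>>>(kargs);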
698 
699  CK_TILE_DEVICE void operator()(Kargs kargs) const
700  {
701  if constexpr(kIsAvailable)
702  run_(std::move(kargs));
703  }
704 
705  CK_TILE_DEVICE void run_(Kargs kargs) const
706  {
707  // allocate LDS
708  __shared__ char smem_ptr[GetSmemSize()];
709 
710  // divide problem
711  const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
712 
713  const index_t i_n0 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN0);
714 
715  long_index_t batch_offset_q = 0;
716  long_index_t batch_offset_k = 0;
717  long_index_t batch_offset_v = 0;
718  long_index_t batch_offset_bias = 0;
719  long_index_t batch_offset_randval = 0;
720  long_index_t batch_offset_do = 0;
721  long_index_t batch_offset_lsed = 0;
722  long_index_t batch_offset_dq_acc = 0;
723  long_index_t batch_offset_dk = 0;
724  long_index_t batch_offset_dv = 0;
725  long_index_t batch_offset_dbias = 0;
726 
727  if constexpr(kIsGroupMode)
728  {
729  // get starting offset for each batch
730  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
731  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
732 
733  batch_offset_q = query_start * kargs.stride_q;
734  batch_offset_k = key_start * kargs.stride_k;
735  batch_offset_v = key_start * kargs.stride_v;
736  batch_offset_do = query_start * kargs.stride_do;
737  batch_offset_lsed = query_start;
738  batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
739  batch_offset_dk = key_start * kargs.stride_dk;
740  batch_offset_dv = key_start * kargs.stride_dv;
741  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
742  {
743  batch_offset_bias = query_start * kargs.stride_bias;
744  }
745  if constexpr(kHasBiasGrad)
746  {
747  batch_offset_dbias = query_start * kargs.stride_dbias;
748  }
749  else
750  {
751  batch_offset_dbias = key_start;
752  }
753  if constexpr(kIsStoreRandval)
754  {
755  batch_offset_randval = query_start * kargs.stride_randval;
756  }
757 
758  // Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
759  if(kargs.cu_seqlen_q_ptr != nullptr)
760  {
761  kargs.seqlen_q =
762  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
763  }
764  else
765  {
766  // get the real # of queries under group mode
767  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
768  const ck_tile::index_t physical_seqlen_q =
769  adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
770  kargs.seqlen_q =
771  kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q;
772  }
773 
774  // Priority: cu_seqlen_k_ptr > seqlen_k_ptr > seqstart_k
775  if(kargs.cu_seqlen_k_ptr != nullptr)
776  {
777  kargs.seqlen_k =
778  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
779  }
780  else if(kargs.seqlen_k_ptr != nullptr)
781  {
782  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
783  }
784  else
785  {
786  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
787  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
788  }
789 
790  // skip if logical lengths are zero
791  if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0)
792  {
793  return;
794  }
795 
796  // the # of required blocks differs per group, so terminate
797  // unnecessary blocks early
798  if constexpr(!kUseQrQtrDorPipeline)
799  if(kargs.seqlen_k <= i_n0)
800  return;
801  }
802  else
803  {
804  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
805  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
806  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
807  batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
808  batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
809  batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
810  batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
811  batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
812  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
813  {
814  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
815  }
816  if constexpr(kHasBiasGrad)
817  {
818  batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
819  }
820  if constexpr(kIsStoreRandval)
821  {
822  batch_offset_randval =
823  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
824  }
825  }
826 
827  // for simplicity, we fold the batch stride into the pointer
828  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
829  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
830  batch_offset_q;
831  const KDataType* k_ptr =
832  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
833  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
834  batch_offset_k;
835  const VDataType* v_ptr =
836  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
837  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
838  batch_offset_v;
839  const LSEDataType* lse_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_ptr) +
840  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
841  batch_offset_lsed;
842  const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
843  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
844  batch_offset_lsed;
845  const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
846  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
847  batch_offset_do;
848  auto dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
849  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk + batch_offset_dk;
850  auto dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
851  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv + batch_offset_dv;
852 
853  // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
854  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
855  q_ptr,
856  make_tuple(kargs.seqlen_q, kargs.hdim_q),
857  make_tuple(kargs.stride_q, 1),
859  number<1>{});
860  const auto q_dram = pad_tensor_view(
861  q_dram_naive,
862  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
863  sequence<false, (kPadHeadDimQ > 0)>{});
864 
865  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
866  k_ptr,
867  make_tuple(kargs.seqlen_k, kargs.hdim_q),
868  make_tuple(kargs.stride_k, 1),
870  number<1>{});
871  const auto k_dram = pad_tensor_view(
872  k_dram_naive,
873  make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
874  sequence<false, (kPadHeadDimQ > 0)>{});
875 
876  const auto v_dram = [&]() {
877  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
878  v_ptr,
879  make_tuple(kargs.seqlen_k, kargs.hdim_v),
880  make_tuple(kargs.stride_v, 1),
882  number<1>{});
883  return pad_tensor_view(
884  v_dram_naive,
885  make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
886  sequence<false, (kPadHeadDimV > 0)>{});
887  }();
888 
889  // lse and d are fine to read unpadded, as they are not on the reduction dimension
890  const auto lse_dram = make_naive_tensor_view_packed<address_space_enum::global>(
891  lse_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});
892 
893  const auto d_dram = make_naive_tensor_view_packed<address_space_enum::global>(
894  d_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});
895 
896  const auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
897  do_ptr,
898  make_tuple(kargs.seqlen_q, kargs.hdim_v),
899  make_tuple(kargs.stride_do, 1),
901  number<1>{});
902  const auto do_dram = pad_tensor_view(
903  do_dram_naive,
904  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
905  sequence<false, (kPadHeadDimV > 0)>{});
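// Pattern used for all the views above: make_naive_tensor_view describes the
// raw global-memory layout (lengths + strides), then pad_tensor_view logically
// extends the flagged dimensions up to a tile multiple so edge tiles can be
// loaded with uniform code. Generic sketch (tile sizes are placeholders):
//
//   auto naive  = make_naive_tensor_view<address_space_enum::global>(
//       ptr, make_tuple(rows, cols), make_tuple(row_stride, 1),
//       number<8>{}, number<1>{});
//   auto padded = pad_tensor_view(naive,
//       make_tuple(number<64>{}, number<32>{}), // tile lengths
//       sequence<false, true>{});               // pad cols only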
906 
907  auto q_dram_window = make_tile_window(
908  q_dram,
909  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
910  {0, 0});
911 
912  auto k_dram_window = make_tile_window(
913  k_dram,
914  make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
915  {i_n0, 0});
916 
917  auto v_dram_window = make_tile_window(
918  v_dram,
919  make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
920  {i_n0, 0});
921 
922  auto do_dram_window = make_tile_window(
923  do_dram,
924  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
925  {0, 0});
926 
927  auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
928  constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
929  using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
930 
931  auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
932  if constexpr(kUseKSplit)
933  return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
934  static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
935  batch_offset_dq_acc;
936  else
937  return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
938  batch_offset_dq_acc;
939  }();
940 
941  constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
942  memory_operation_enum::set, memory_operation_enum::atomic_add);
943  const auto dq_acc_dram_naive =
944  make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
945  dq_acc_ptr,
946  make_tuple(kargs.seqlen_q, kargs.hdim_q),
947  make_tuple(kargs.stride_dq_acc, 1),
949  number<1>{});
950  const auto dq_acc_dram = pad_tensor_view(
951  dq_acc_dram_naive,
952  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
953  sequence<false, (kPadHeadDimQ > 0)>{});
954  return make_tile_window(
955  dq_acc_dram,
956  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
957  {0, 0});
958  }();
959 
960  auto lse_dram_window =
961  make_tile_window(lse_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
962 
963  auto d_dram_window = make_tile_window(d_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
964 
967  constexpr auto bias_dram_window_lengths =
968  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
969  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
970  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
971  {
972  const BiasDataType* bias_ptr =
973  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
974  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
975  batch_offset_bias;
976 
977  const auto bias_dram = [&]() {
978  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
979  bias_ptr,
980  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
981  make_tuple(kargs.stride_bias, 1),
983  number<1>{});
984 
985  return pad_tensor_view(
986  bias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
987  }();
988 
989  return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0});
990  }
991  else
992  {
993  return make_null_tile_window(bias_dram_window_lengths);
994  }
995  }();
996 
997  auto dbias_dram_window = [&, i_nhead_ = i_nhead]() {
998  if constexpr(kHasBiasGrad)
999  {
1000  BiasGradDataType* dbias_ptr =
1001  reinterpret_cast<BiasGradDataType*>(kargs.dbias_ptr) +
1002  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dbias +
1003  batch_offset_dbias;
1004 
1005  auto dbias_dram = [&]() {
1006  const auto dbias_dram_naive =
1007  make_naive_tensor_view<address_space_enum::global>(
1008  dbias_ptr,
1009  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1010  make_tuple(kargs.stride_dbias, 1),
1012  number<1>{});
1013 
1014  return pad_tensor_view(
1015  dbias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
1016  }();
1017 
1018  return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0});
1019  }
1020  else
1021  {
1022  return make_null_tile_window(bias_dram_window_lengths);
1023  }
1024  }();
1025 
1026  // workaround: lambdas cannot capture structured bindings before C++20, so copy i_batch/i_nhead explicitly
1027  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1028  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1029  {
1030  // data loading, shared by entire wg
1031  // TODO: how to use s_read?
1032  AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
1033  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1034  slope *= ck_tile::log2e_v<>;
1035  if constexpr(kHasMask)
1036  {
1037  return make_alibi_from_lr_mask<AccDataType, false>(slope,
1038  kargs.window_size_left,
1039  kargs.window_size_right,
1040  kargs.seqlen_q,
1041  kargs.seqlen_k,
1042  kargs.mask_type);
1043  }
1044  else
1045  {
1046  return Alibi<AccDataType, false>{
1047  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1048  }
1049  }
1050  else
1051  {
1052  return EmptyPositionEncoding<AccDataType>{};
1053  }
1054  }();
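// Note on the "slope *= log2e" conversion above: the softmax pipeline works
// in base-2 exponentials, so any additive logit term must be rescaled once:
//   exp(s + slope * d) == exp2((s + slope * d) * log2e)
// which is also why MakeKargs stores scale = raw_scale * log2e.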
1055 
1056  // dropout
1057  float rp_undrop = 1;
1058  float scale_rp_undrop = 1;
1059  if constexpr(kHasDropout)
1060  {
1061  rp_undrop = kargs.rp_undrop;
1062  scale_rp_undrop = kargs.scale_rp_undrop;
1063  }
1064  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1065  if constexpr(kHasDropout)
1066  {
1067  return FmhaDropout{i_batch_,
1068  i_nhead_,
1069  kargs.num_head_q,
1070  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1071  : *kargs.drop_seed.ptr,
1072  kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
1073  : *kargs.drop_offset.ptr,
1074  kargs.rp_undrop,
1075  kargs.p_undrop_in_uint8_t};
1076  }
1077  else
1078  {
1079  return FmhaDropout{};
1080  };
1081  }();
1082 
1083  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1084  constexpr auto randval_dram_window_lengths =
1085  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
1086  if constexpr(kIsStoreRandval)
1087  {
1088  RandValOutputDataType* rand_val_ptr =
1089  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1090  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1091  batch_offset_randval;
1092 
1093  const auto randval_dram = [&]() {
1094  const auto randval_dram_naive =
1095  make_naive_tensor_view<address_space_enum::global>(
1096  rand_val_ptr,
1097  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1098  make_tuple(kargs.stride_randval, 1),
1099  number<1>{},
1100  number<1>{});
1101 
1102  return pad_tensor_view(
1103  randval_dram_naive, randval_dram_window_lengths, sequence<false, true>{});
1104  }();
1105 
1106  return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0});
1107  }
1108  else
1109  {
1110  return make_null_tile_window(randval_dram_window_lengths);
1111  }
1112  }();
1113 
1114  FmhaMask mask = [&]() {
1115  if constexpr(kHasMask)
1116  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1117  kargs.window_size_left,
1118  kargs.window_size_right,
1119  kargs.seqlen_q,
1120  kargs.seqlen_k,
1121  kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
1122  else
1123  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1124  }();
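// Reading aid (the authoritative mapping lives in
// make_generic_attention_mask_from_lr_window): window_size_left/right bound
// how far attention may reach to the left/right of the diagonal, with a
// negative value usually meaning unlimited on that side, e.g.
//   (-1,  0) -> causal mask
//   (-1, -1) -> no masking
//   ( w,  0) -> sliding-window causal mask of width w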
1125 
1126  auto dk_dram = [&]() {
1127  const auto dk_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1128  dk_ptr,
1129  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1130  make_tuple(kargs.stride_dk, 1),
1132  number<1>{});
1133 
1134  return pad_tensor_view(
1135  dk_dram_naive,
1136  make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
1137  sequence<false, (kPadHeadDimQ > 0)>{});
1138  }();
1139 
1140  auto dv_dram = [&]() {
1141  const auto dv_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1142  dv_ptr,
1143  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1144  make_tuple(kargs.stride_dv, 1),
1146  number<1>{});
1147 
1148  return pad_tensor_view(
1149  dv_dram_naive,
1150  make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
1151  sequence<false, (kPadHeadDimV > 0)>{});
1152  }();
1153 
1154  auto dk_dram_window = make_tile_window(
1155  dk_dram,
1156  make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
1157  {i_n0, 0});
1158 
1159  auto dv_dram_window = make_tile_window(
1160  dv_dram,
1161  make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
1162  {i_n0, 0});
1163  if constexpr(!kUseQrQtrDorPipeline)
1164  {
1165  auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(smem_ptr,
1166  q_dram_window,
1167  k_dram_window,
1168  v_dram_window,
1169  bias_dram_window,
1170  randval_dram_window,
1171  do_dram_window,
1172  lse_dram_window,
1173  d_dram_window,
1174  dq_dram_window,
1175  dbias_dram_window,
1176  mask,
1177  position_encoding,
1178  kargs.raw_scale,
1179  kargs.scale,
1180  rp_undrop,
1181  scale_rp_undrop,
1182  dropout);
1183 
1184 #if defined(__gfx12__)
1185  // Workaround for a compiler bug (SWDEV-559729): v_wmma instructions can be incorrectly
1186  // placed in divergent branches used to store padded tensors (when some lanes are
1187  // inactive due to padding). Inline asm with dummy dependencies on VGPRs of the tensors
1188  // prevents the compiler doing this.
1189  if constexpr(kPadHeadDimQ > 0)
1190  {
1191  impl::insert_dummy_dep(dk_acc_tile.get_thread_buffer());
1192  }
1193  if constexpr(kPadHeadDimV > 0)
1194  {
1195  impl::insert_dummy_dep(dv_acc_tile.get_thread_buffer());
1196  }
1197 #endif
1198 
1199  KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile, nullptr);
1200  VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile, nullptr);
1201  }
1202  else
1203  {
1204  FmhaPipeline{}(smem_ptr,
1205  q_dram_window,
1206  k_dram_window,
1207  v_dram_window,
1208  bias_dram_window,
1209  randval_dram_window,
1210  do_dram_window,
1211  lse_dram_window,
1212  d_dram_window,
1213  dq_dram_window,
1214  dk_dram_window,
1215  dv_dram_window,
1216  dbias_dram_window,
1217  KGradEpiloguePipeline{},
1218  VGradEpiloguePipeline{},
1219  QGradEpiloguePipeline{},
1220  mask,
1221  position_encoding,
1222  kargs.raw_scale,
1223  kargs.scale,
1224  rp_undrop,
1225  scale_rp_undrop,
1226  dropout);
1227  }
1228  }
1229 };
1230 
1231 template <typename FmhaBwdOGradDotO_>
1232 struct FmhaBwdOGradDotOKernel
1233 {
1234  using FmhaBwdOGradDotO = remove_cvref_t<FmhaBwdOGradDotO_>;
1235  static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize;
1236  static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
1237  static constexpr ck_tile::index_t kM0 = kBlockSize;
1238  static constexpr ck_tile::index_t kVHeaddim = FmhaBwdOGradDotO::kVHeaddim;
1239 
1240  using ODataType = remove_cvref_t<typename FmhaBwdOGradDotO::ODataType>;
1241  using OGradDataType = remove_cvref_t<typename FmhaBwdOGradDotO::OGradDataType>;
1242  using DDataType = remove_cvref_t<typename FmhaBwdOGradDotO::DDataType>;
1243 
1244  static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode;
1245  static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ;
1246  static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV;
1247 
1248  // clang-format off
1249  template <typename T> struct t2s;
1250  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
1251  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
1252  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
1253  // clang-format on
1254 
1255  CK_TILE_HOST static std::string GetName()
1256  {
1257  // sync with generate.py
1258  // clang-format off
1259 
1260  #define _SS_ std::string
1261  #define _TS_ std::to_string
1262  auto pn = [&] () {
1263  std::string n;
1264  if (kPadSeqLenQ) n += "s";
1265  if (kPadHeadDimV) n += "dv";
1266  return n.empty() ? n : std::string("p") + n; }();
1267  return
1268  _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s<ODataType>::name) +
1269  "_b" + _TS_(kM0) + "_" + (kIsGroupMode ? "group" : "batch") + "_" +
1270  ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn);
1271  #undef _SS_
1272  #undef _TS_
1273  // clang-format on
1274  }
1275 
1276  // kargs use aggregate initialization, so no constructor is provided;
1277  // inheritance is used to minimize the karg size.
1278  // Users need to use the MakeKargs() function to create kargs.
1279  struct FmhaBwdOGradDotOCommonKargs
1280  {
1281  const void* o_ptr;
1282  const void* do_ptr;
1283  void* d_ptr;
1284 
1285  float p_undrop;
1286 
1287  ck_tile::index_t seqlen_q;
1288  ck_tile::index_t hdim_v;
1289 
1290  ck_tile::index_t stride_do;
1291  ck_tile::index_t stride_o;
1292 
1293  ck_tile::index_t nhead_stride_do;
1294  ck_tile::index_t nhead_stride_o;
1295  ck_tile::index_t nhead_stride_d;
1296  };
1297 
1298  struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
1299  {
1300  ck_tile::index_t batch_stride_do;
1301  ck_tile::index_t batch_stride_o;
1302  ck_tile::index_t batch_stride_d;
1303  };
1304 
1305  struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
1306  {
1307  const int32_t* seqstart_q_ptr; // physical sequence start offsets [batch+1]
1308  const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
1309  const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
1310  };
1311 
1312  using Kargs = std::
1313  conditional_t<kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs>;
1314 
1315  template <bool Cond = !kIsGroupMode>
1316  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1317  MakeKargs(const void* o_ptr,
1318  const void* do_ptr,
1319  void* d_ptr,
1320  float p_undrop,
1321  ck_tile::index_t seqlen_q,
1322  ck_tile::index_t hdim_v,
1323  ck_tile::index_t stride_do,
1324  ck_tile::index_t stride_o,
1325  ck_tile::index_t nhead_stride_do,
1326  ck_tile::index_t nhead_stride_o,
1327  ck_tile::index_t nhead_stride_d,
1328  ck_tile::index_t batch_stride_do,
1329  ck_tile::index_t batch_stride_o,
1330  ck_tile::index_t batch_stride_d)
1331  {
1332  Kargs kargs{{o_ptr,
1333  do_ptr,
1334  d_ptr,
1335  p_undrop,
1336  seqlen_q,
1337  hdim_v,
1338  stride_do,
1339  stride_o,
1340  nhead_stride_do,
1341  nhead_stride_o,
1342  nhead_stride_d},
1343  batch_stride_do,
1344  batch_stride_o,
1345  batch_stride_d};
1346 
1347  return kargs;
1348  }
1349 
1350  template <bool Cond = kIsGroupMode>
1351  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1352  MakeKargs(const void* o_ptr,
1353  const void* do_ptr,
1354  void* d_ptr,
1355  float p_undrop,
1356  const void* seqstart_q_ptr,
1357  const void* seqlen_q_ptr,
1358  const void* cu_seqlen_q_ptr,
1359  ck_tile::index_t hdim_v,
1360  ck_tile::index_t stride_do,
1361  ck_tile::index_t stride_o,
1362  ck_tile::index_t nhead_stride_do,
1363  ck_tile::index_t nhead_stride_o,
1364  ck_tile::index_t nhead_stride_d)
1365  {
1366  Kargs kargs{{o_ptr,
1367  do_ptr,
1368  d_ptr,
1369  p_undrop,
1370  -1, // seqlen will be updated by another pointer
1371  hdim_v,
1372  stride_do,
1373  stride_o,
1374  nhead_stride_do,
1375  nhead_stride_o,
1376  nhead_stride_d},
1377  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
1378  reinterpret_cast<const int32_t*>(seqlen_q_ptr),
1379  reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr)};
1380 
1381  return kargs;
1382  }
1383 
1384  CK_TILE_HOST static constexpr auto
1385  GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
1386  {
1387  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
1388  }
1389 
1390  CK_TILE_DEVICE static constexpr auto GetTileIndex()
1391  {
1392  const index_t i_block = blockIdx.x;
1393  const index_t i_nhead = blockIdx.y;
1394  const index_t i_batch = blockIdx.z;
1395 
1396  return ck_tile::make_tuple(i_block, i_nhead, i_batch);
1397  }
1398 
1399  CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); }
1400 
1401  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
1402 
1403  CK_TILE_DEVICE void operator()(Kargs kargs) const
1404  {
1405  // divide problem
1406  const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
1407 
1408  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0);
1409 
1410  long_index_t batch_offset_o = 0;
1411  long_index_t batch_offset_do = 0;
1412  long_index_t batch_offset_d = 0;
1413 
1414  if constexpr(kIsGroupMode)
1415  {
1416  // get starting offset for each batch
1417  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1418 
1419  batch_offset_o = query_start * kargs.stride_o;
1420  batch_offset_do = query_start * kargs.stride_do;
1421  batch_offset_d = query_start;
1422 
1423  // Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
1424  if(kargs.cu_seqlen_q_ptr != nullptr)
1425  {
1426  kargs.seqlen_q =
1427  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1428  }
1429  else
1430  {
1431  // get the real # of queries under group mode
1432  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1433  const ck_tile::index_t physical_seqlen_q =
1434  adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1435  kargs.seqlen_q = kargs.seqlen_q_ptr
1436  ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
1437  : physical_seqlen_q;
1438  }
1439 
1440  // the # of required blocks differs per group, so terminate
1441  // unnecessary blocks early
1442  if(kargs.seqlen_q <= i_m0)
1443  {
1444  return;
1445  }
1446  }
1447  else
1448  {
1449  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1450  batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
1451  batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
1452  }
1453 
1454  // for simplicity, we fold the batch stride into the pointer
1455  const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr) +
1456  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1457  batch_offset_o;
1458  const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
1459  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
1460  batch_offset_do;
1461  DDataType* d_ptr = reinterpret_cast<DDataType*>(kargs.d_ptr) +
1462  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_d +
1463  batch_offset_d;
1464 
1465  // O/dO/D DRAM and DRAM window
1466  const auto o_dram = [&]() {
1467  auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1468  o_ptr,
1469  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1470  make_tuple(kargs.stride_o, 1),
1472  number<1>{});
1473  return pad_tensor_view(o_dram_naive,
1474  make_tuple(number<kM0>{}, number<kVHeaddim>{}),
1475  sequence<kPadSeqLenQ, kPadHeadDimV>{});
1476  }();
1477  const auto do_dram = [&]() {
1478  auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1479  do_ptr,
1480  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1481  make_tuple(kargs.stride_do, 1),
1483  number<1>{});
1484  return pad_tensor_view(do_dram_naive,
1485  make_tuple(number<kM0>{}, number<kVHeaddim>{}),
1486  sequence<kPadSeqLenQ, kPadHeadDimV>{});
1487  }();
1488  auto d_dram = [&]() {
1489  const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
1490  d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
1491  return pad_tensor_view(
1492  d_dram_naive, make_tuple(number<kM0>{}), sequence<kPadSeqLenQ>{});
1493  }();
1494 
1495  auto o_dram_window =
1496  make_tile_window(o_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
1497 
1498  auto do_dram_window =
1499  make_tile_window(do_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
1500 
1501  auto d_dram_window = make_tile_window(d_dram, make_tuple(number<kM0>{}), {i_m0});
1502 
1503  FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop);
1504  }
1505 };
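// What this kernel computes per query row: D[i] = p_undrop * sum_d dO[i,d] * O[i,d],
// consumed as the softmax-backward correction term by the main kernel above.
// Scalar reference sketch (plain loops, illustrative only):
//
//   for(int i = 0; i < seqlen_q; ++i) {
//       float acc = 0;
//       for(int d = 0; d < hdim_v; ++d)
//           acc += dO[i * stride_do + d] * O[i * stride_o + d];
//       D[i] = acc * p_undrop;
//   }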
1506 
1507 template <typename FmhaBwdConvertQGrad_>
1508 struct FmhaBwdConvertQGradKernel
1509 {
1510  using FmhaBwdConvertQGrad = remove_cvref_t<FmhaBwdConvertQGrad_>;
1511  static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize;
1512  static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
1513  static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0;
1514  static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0;
1515  static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim;
1516 
1517  using AccDataType = remove_cvref_t<typename FmhaBwdConvertQGrad::AccDataType>;
1518  using QGradDataType = remove_cvref_t<typename FmhaBwdConvertQGrad::QGradDataType>;
1519 
1520  static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode;
1521  static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
1522  static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
1523  static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
1524 
1525  // clang-format off
1526  template <typename T> struct t2s;
1527  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
1528  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
1529  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
1530  // clang-format on
1531 
1532  CK_TILE_HOST static std::string GetName()
1533  {
1534  // sync with generate.py
1535  // clang-format off
1536 
1537  #define _SS_ std::string
1538  #define _TS_ std::to_string
1539  auto pn = [&] () {
1540  std::string n;
1541  if (kPadSeqLenQ) n += "s";
1542  if (kPadHeadDimQ) n += "d";
1543  return n.empty() ? n : std::string("p") + n; }();
1544  return
1545  _SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_"
1547  + "b" + _TS_(kM0) + "x" + _TS_(kN0) + "_"
1548  + (kIsGroupMode ? "group" : "batch") + "_"
1549  + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn)
1550  + (kIsDeterministic ? "_deterministic" : "_ndeterministic") ;
1551  #undef _SS_
1552  #undef _TS_
1553  // clang-format on
1554  }
1555 
1556  // to avoid the duplicated base class problem, introduce a template arg
1557  template <ck_tile::index_t I>
1558  struct FmhaBwdConvertQGradEmptyKargs
1559  {
1560  };
1561 
1562  // kargs use aggregate initialization, so no constructor is provided;
1563  // inheritance is used to minimize the karg size.
1564  // Users need to use the MakeKargs() function to create kargs.
1565  struct FmhaBwdConvertQGradCommonKargs
1566  {
1567  const void* dq_acc_ptr;
1568  void* dq_ptr;
1569 
1570  ck_tile::index_t seqlen_q;
1571  ck_tile::index_t seqlen_k;
1572  ck_tile::index_t hdim_q;
1573 
1574  ck_tile::index_t stride_dq;
1575  ck_tile::index_t stride_dq_acc;
1576  ck_tile::index_t nhead_stride_dq;
1577  ck_tile::index_t nhead_stride_dq_acc;
1578  };
1579 
1580  struct FmhaBwdConvertQGradDeterministicKargs
1581  {
1582  ck_tile::index_t split_stride_dq_acc = 0;
1583  };
1584 
1585  struct FmhaBwdConvertQGradBatchModeKargs
1586  : FmhaBwdConvertQGradCommonKargs,
1587  std::conditional_t<kIsDeterministic,
1588  FmhaBwdConvertQGradDeterministicKargs,
1589  FmhaBwdConvertQGradEmptyKargs<0>>
1590  {
1591  ck_tile::index_t batch_stride_dq;
1592  ck_tile::index_t batch_stride_dq_acc;
1593  };
1594 
1595  struct FmhaBwdConvertQGradGroupModeKargs
1596  : FmhaBwdConvertQGradCommonKargs,
1597  std::conditional_t<kIsDeterministic,
1598  FmhaBwdConvertQGradDeterministicKargs,
1599  FmhaBwdConvertQGradEmptyKargs<0>>
1600  {
1601  const int32_t* seqstart_q_ptr; // physical sequence start offsets [batch+1]
1602  const int32_t* seqstart_k_ptr; // physical sequence start offsets [batch+1]
1603  const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
1604  const int32_t* seqlen_k_ptr; // per-batch actual length [batch]
1605  const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
1606  const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional
1607  };
1608 
1608 
1609  using Kargs = std::conditional_t<kIsGroupMode,
1610  FmhaBwdConvertQGradGroupModeKargs,
1611  FmhaBwdConvertQGradBatchModeKargs>;
1612 
1613  template <bool Cond = !kIsGroupMode>
1614  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1615  MakeKargs(const void* dq_acc_ptr,
1616  void* dq_ptr,
1617  ck_tile::index_t seqlen_q,
1618  ck_tile::index_t seqlen_k,
1619  ck_tile::index_t hdim_q,
1620  ck_tile::index_t stride_dq,
1621  ck_tile::index_t stride_dq_acc,
1622  ck_tile::index_t nhead_stride_dq,
1623  ck_tile::index_t nhead_stride_dq_acc,
1624  ck_tile::index_t batch_stride_dq,
1625  ck_tile::index_t batch_stride_dq_acc,
1626  ck_tile::index_t split_stride_dq_acc)
1627  {
1628  Kargs kargs{{dq_acc_ptr,
1629  dq_ptr,
1630  seqlen_q,
1631  seqlen_k,
1632  hdim_q,
1633  stride_dq,
1634  stride_dq_acc,
1635  nhead_stride_dq,
1636  nhead_stride_dq_acc},
1637  {},
1638  batch_stride_dq,
1639  batch_stride_dq_acc};
1640 
1641  if constexpr(kIsDeterministic)
1642  {
1643  kargs.split_stride_dq_acc = split_stride_dq_acc;
1644  }
1645 
1646  return kargs;
1647  }
1648 
1649  template <bool Cond = kIsGroupMode>
1650  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1651  MakeKargs(const void* dq_acc_ptr,
1652  void* dq_ptr,
1653  const void* seqstart_q_ptr,
1654  const void* seqstart_k_ptr,
1655  const void* seqlen_q_ptr,
1656  const void* seqlen_k_ptr,
1657  const void* cu_seqlen_q_ptr,
1658  const void* cu_seqlen_k_ptr,
1659  ck_tile::index_t hdim_q,
1660  ck_tile::index_t stride_dq,
1661  ck_tile::index_t stride_dq_acc,
1662  ck_tile::index_t nhead_stride_dq,
1663  ck_tile::index_t nhead_stride_dq_acc,
1664  ck_tile::index_t split_stride_dq_acc)
1665  {
1666  Kargs kargs{{dq_acc_ptr,
1667  dq_ptr,
1668  -1, // seqlen will be updated by another pointer
1669  -1, //
1670  hdim_q,
1671  stride_dq,
1672  stride_dq_acc,
1673  nhead_stride_dq,
1674  nhead_stride_dq_acc},
1675  {},
1676  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
1677  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
1678  reinterpret_cast<const int32_t*>(seqlen_q_ptr),
1679  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
1680  reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
1681  reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
1682 
1683  if constexpr(kIsDeterministic)
1684  {
1685  kargs.split_stride_dq_acc = split_stride_dq_acc;
1686  }
1687 
1688  return kargs;
1689  }
1690 
1691  CK_TILE_HOST static constexpr auto
1692  GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
1693  {
1694  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
1695  }
1696 
1697  CK_TILE_DEVICE static constexpr auto GetTileIndex()
1698  {
1699  const index_t i_block = blockIdx.x;
1700  const index_t i_nhead = blockIdx.y;
1701  const index_t i_batch = blockIdx.z;
1702 
1703  return ck_tile::make_tuple(i_block, i_nhead, i_batch);
1704  }
1705 
1706  CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); }
1707 
1708  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
1709 
1710  CK_TILE_DEVICE void operator()(Kargs kargs) const
1711  {
1712  // divide problem
1713  const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
1714 
1715  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0);
1716 
1717  long_index_t batch_offset_dq = 0;
1718  long_index_t batch_offset_dq_acc = 0;
1719  if constexpr(kIsGroupMode)
1720  {
1721  // get starting offset for each batch
1722  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1723  batch_offset_dq = query_start * kargs.stride_dq;
1724  batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
1725 
1726  if(kargs.cu_seqlen_q_ptr != nullptr)
1727  {
1728  kargs.seqlen_q =
1729  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1730  }
1731  else
1732  {
1733  // get the real # of queries under group mode
1734  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1735  const ck_tile::index_t physical_seqlen_q =
1736  adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1737  kargs.seqlen_q = kargs.seqlen_q_ptr
1738  ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
1739  : physical_seqlen_q;
1740  }
1741 
1742  if constexpr(kIsDeterministic)
1743  {
1744  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1745  const ck_tile::index_t physical_seqlen_k =
1746  adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1747 
1748  // Priority: cu_seqlen_k_ptr > seqlen_k_ptr > physical_seqlen_k
1749  if(kargs.cu_seqlen_k_ptr != nullptr)
1750  {
1751  kargs.seqlen_k =
1752  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1753  }
1754  else
1755  {
1756  kargs.seqlen_k =
1757  kargs.seqlen_k_ptr
1758  ? static_cast<ck_tile::index_t>(kargs.seqlen_k_ptr[i_batch])
1759  : physical_seqlen_k;
1760  }
1761  }
1762  // the # of required blocks differs per group, so terminate
1763  // unnecessary blocks early
1764  if(kargs.seqlen_q <= i_m0)
1765  {
1766  return;
1767  }
1768  }
1769  else
1770  {
1771  batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
1772  batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
1773  }
1774 
1775  // for simplicity, we fold the batch stride into the pointer
1776  QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
1777  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
1778  batch_offset_dq;
1779 
1780  // dQAcc/dQ DRAM and DRAM window
1781  const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
1782  if constexpr(kIsDeterministic)
1783  {
1784  const AccDataType* dq_acc_ptr =
1785  reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
1786  static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
1787  batch_offset_dq_acc;
1788 
1789  const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
1790 
1791  auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1792  dq_acc_ptr,
1793  make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
1794  make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
1796  number<1>{});
1797  return pad_tensor_view(dq_acc_dram_naive,
1798  make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
1799  sequence<false, kPadSeqLenQ, kPadHeadDimQ>{});
1800  }
1801  else
1802  {
1803  const AccDataType* dq_acc_ptr =
1804  reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
1805  static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
1806  batch_offset_dq_acc;
1807 
1808  auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1809  dq_acc_ptr,
1810  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1811  make_tuple(kargs.stride_dq_acc, 1),
1813  number<1>{});
1814  return pad_tensor_view(dq_acc_dram_naive,
1815  make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
1816  sequence<kPadSeqLenQ, kPadHeadDimQ>{});
1817  }
1818  }();
1819 
1820  auto dq_dram = [&]() {
1821  auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1822  dq_ptr,
1823  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1824  make_tuple(kargs.stride_dq, 1),
1826  number<1>{});
1827  return pad_tensor_view(dq_dram_naive,
1828  make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
1829  sequence<kPadSeqLenQ, kPadHeadDimQ>{});
1830  }();
1831 
1832  auto dq_acc_dram_window = [&]() {
1833  if constexpr(kIsDeterministic)
1834  {
1835  return make_tile_window(
1836  dq_acc_dram,
1837  make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
1838  {0, i_m0, 0});
1839  }
1840  else
1841  {
1842  return make_tile_window(
1843  dq_acc_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
1844  }
1845  }();
1846 
1847  auto dq_dram_window =
1848  make_tile_window(dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
1849 
1850  if constexpr(kIsDeterministic)
1851  {
1852  const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
1853  FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
1854  }
1855  else
1856  {
1857  FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
1858  }
1859  }
1860 };
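// In deterministic mode dq_acc holds one partial dQ per key tile
// (nsplits = ceil(seqlen_k / kN0)) written without atomics, and this kernel
// reduces the splits; otherwise dq_acc is a single atomically-accumulated
// buffer and only needs the type conversion. Scalar reduction sketch
// (illustrative only):
//
//   for(int m = 0; m < seqlen_q; ++m)
//       for(int d = 0; d < hdim_q; ++d) {
//           float acc = 0;
//           for(int s = 0; s < nsplits; ++s)
//               acc += dq_acc[s * split_stride_dq_acc + m * stride_dq_acc + d];
//           dq[m * stride_dq + d] = QGradDataType(acc);
//       }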
1861 
1862 } // namespace ck_tile
Definition: fmha_bwd_kernel.hpp:1708
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:1710
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::AccDataType > AccDataType
Definition: fmha_bwd_kernel.hpp:1517
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_bwd_kernel.hpp:1706
static constexpr ck_tile::index_t kM0
Definition: fmha_bwd_kernel.hpp:1513
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, const void *cu_seqlen_q_ptr, const void *cu_seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition: fmha_bwd_kernel.hpp:1651
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::QGradDataType > QGradDataType
Definition: fmha_bwd_kernel.hpp:1518
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition: fmha_bwd_kernel.hpp:1692
static constexpr ck_tile::index_t kN0
Definition: fmha_bwd_kernel.hpp:1514
ck_tile::remove_cvref_t< FmhaBwdConvertQGrad_ > FmhaBwdConvertQGrad
Definition: fmha_bwd_kernel.hpp:1510
std::conditional_t< kIsGroupMode, FmhaBwdConvertQGradGroupModeKargs, FmhaBwdConvertQGradBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:1611
static constexpr ck_tile::index_t kQKHeaddim
Definition: fmha_bwd_kernel.hpp:1515
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:1532
Definition: fmha_bwd_kernel.hpp:192
const void * alibi_slope_ptr
Definition: fmha_bwd_kernel.hpp:194
ck_tile::index_t alibi_slope_stride
Definition: fmha_bwd_kernel.hpp:195
ck_tile::index_t batch_stride_dbias
Definition: fmha_bwd_kernel.hpp:207
ck_tile::index_t batch_stride_bias
Definition: fmha_bwd_kernel.hpp:188
ck_tile::index_t batch_stride_randval
Definition: fmha_bwd_kernel.hpp:272
Definition: fmha_bwd_kernel.hpp:291
ck_tile::index_t batch_stride_v
Definition: fmha_bwd_kernel.hpp:294
ck_tile::index_t batch_stride_k
Definition: fmha_bwd_kernel.hpp:293
ck_tile::index_t batch_stride_q
Definition: fmha_bwd_kernel.hpp:292
ck_tile::index_t batch_stride_do
Definition: fmha_bwd_kernel.hpp:295
ck_tile::index_t batch_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:297
ck_tile::index_t batch_stride_dk
Definition: fmha_bwd_kernel.hpp:298
ck_tile::index_t batch_stride_dv
Definition: fmha_bwd_kernel.hpp:299
ck_tile::index_t batch_stride_lsed
Definition: fmha_bwd_kernel.hpp:296
ck_tile::index_t nhead_stride_dbias
Definition: fmha_bwd_kernel.hpp:202
void * dbias_ptr
Definition: fmha_bwd_kernel.hpp:200
ck_tile::index_t stride_dbias
Definition: fmha_bwd_kernel.hpp:201
Definition: fmha_bwd_kernel.hpp:180
ck_tile::index_t stride_bias
Definition: fmha_bwd_kernel.hpp:182
ck_tile::index_t nhead_stride_bias
Definition: fmha_bwd_kernel.hpp:183
const void * bias_ptr
Definition: fmha_bwd_kernel.hpp:181
uint8_t p_undrop_in_uint8_t
Definition: fmha_bwd_kernel.hpp:263
float rp_undrop
Definition: fmha_bwd_kernel.hpp:261
ck_tile::index_t nhead_stride_randval
Definition: fmha_bwd_kernel.hpp:267
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr, float raw_scale)
Definition: fmha_bwd_kernel.hpp:245
float scale_rp_undrop
Definition: fmha_bwd_kernel.hpp:262
void * rand_val_ptr
Definition: fmha_bwd_kernel.hpp:264
ck_tile::index_t stride_randval
Definition: fmha_bwd_kernel.hpp:266
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
Definition: fmha_bwd_kernel.hpp:232
Definition: fmha_bwd_kernel.hpp:138
ck_tile::index_t nhead_stride_dk
Definition: fmha_bwd_kernel.hpp:175
ck_tile::index_t stride_do
Definition: fmha_bwd_kernel.hpp:164
ck_tile::index_t seqlen_k
Definition: fmha_bwd_kernel.hpp:150
const void * q_ptr
Definition: fmha_bwd_kernel.hpp:139
ck_tile::index_t hdim_q
Definition: fmha_bwd_kernel.hpp:151
ck_tile::index_t nhead_stride_do
Definition: fmha_bwd_kernel.hpp:172
ck_tile::index_t num_head_q
Definition: fmha_bwd_kernel.hpp:156
const void * lse_ptr
Definition: fmha_bwd_kernel.hpp:142
float raw_scale
Definition: fmha_bwd_kernel.hpp:158
ck_tile::index_t nhead_stride_k
Definition: fmha_bwd_kernel.hpp:170
ck_tile::index_t nhead_stride_q
Definition: fmha_bwd_kernel.hpp:169
ck_tile::index_t stride_dv
Definition: fmha_bwd_kernel.hpp:167
ck_tile::index_t nhead_stride_lsed
Definition: fmha_bwd_kernel.hpp:173
void * dq_acc_ptr
Definition: fmha_bwd_kernel.hpp:145
ck_tile::index_t stride_q
Definition: fmha_bwd_kernel.hpp:161
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:149
ck_tile::index_t stride_dk
Definition: fmha_bwd_kernel.hpp:166
const void * do_ptr
Definition: fmha_bwd_kernel.hpp:143
float scale
Definition: fmha_bwd_kernel.hpp:159
void * dk_ptr
Definition: fmha_bwd_kernel.hpp:146
ck_tile::index_t nhead_stride_v
Definition: fmha_bwd_kernel.hpp:171
ck_tile::index_t stride_v
Definition: fmha_bwd_kernel.hpp:163
const void * d_ptr
Definition: fmha_bwd_kernel.hpp:144
const void * k_ptr
Definition: fmha_bwd_kernel.hpp:140
ck_tile::index_t nhead_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:174
ck_tile::index_t nhead_ratio_qk
Definition: fmha_bwd_kernel.hpp:157
void * dv_ptr
Definition: fmha_bwd_kernel.hpp:147
const void * v_ptr
Definition: fmha_bwd_kernel.hpp:141
ck_tile::index_t hdim_v
Definition: fmha_bwd_kernel.hpp:152
ck_tile::index_t stride_k
Definition: fmha_bwd_kernel.hpp:162
ck_tile::index_t nhead_stride_dv
Definition: fmha_bwd_kernel.hpp:176
ck_tile::index_t stride_dq_acc
Definition: fmha_bwd_kernel.hpp:165
ck_tile::index_t split_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:277
bool is_drop_seed_offset_from_host
Definition: fmha_bwd_kernel.hpp:227
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_bwd_kernel.hpp:225
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_bwd_kernel.hpp:226
Definition: fmha_bwd_kernel.hpp:131
Definition: fmha_bwd_kernel.hpp:313
const int32_t * seqstart_k_ptr
Definition: fmha_bwd_kernel.hpp:315
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:314
const int32_t * seqlen_k_ptr
Definition: fmha_bwd_kernel.hpp:317
const int32_t * cu_seqlen_k_ptr
Definition: fmha_bwd_kernel.hpp:319
const int32_t * seqlen_q_ptr
Definition: fmha_bwd_kernel.hpp:316
const int32_t * cu_seqlen_q_ptr
Definition: fmha_bwd_kernel.hpp:318
Definition: fmha_bwd_kernel.hpp:211
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_bwd_kernel.hpp:213
ck_tile::index_t window_size_right
Definition: fmha_bwd_kernel.hpp:212
ck_tile::index_t window_size_left
Definition: fmha_bwd_kernel.hpp:212
Definition: fmha_bwd_kernel.hpp:84
Definition: fmha_bwd_kernel.hpp:35
ck_tile::remove_cvref_t< KGradEpiloguePipeline_ > KGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:37
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:344
static constexpr auto BiasEnum
Definition: fmha_bwd_kernel.hpp:66
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:90
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaDropout > FmhaDropout
Definition: fmha_bwd_kernel.hpp:69
ck_tile::remove_cvref_t< typename FmhaPipeline::OGradDataType > OGradDataType
Definition: fmha_bwd_kernel.hpp:57
static constexpr bool kUseQrQtrDorPipeline
Definition: fmha_bwd_kernel.hpp:42
ck_tile::remove_cvref_t< VGradEpiloguePipeline_ > VGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasGradDataType > BiasGradDataType
Definition: fmha_bwd_kernel.hpp:61
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:692
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:40
static constexpr bool kHasMask
Definition: fmha_bwd_kernel.hpp:70
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_bwd_kernel.hpp:680
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_bwd_kernel.hpp:48
static constexpr CK_TILE_HOST Kargs MakeKargs(Ts... args, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:327
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:63
ck_tile::remove_cvref_t< typename FmhaPipeline::QGradDataType > QGradDataType
Definition: fmha_bwd_kernel.hpp:58
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:671
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, const void *cu_seqlen_q_ptr, const void *cu_seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:512
ck_tile::remove_cvref_t< typename FmhaPipeline::DDataType > DDataType
Definition: fmha_bwd_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_bwd_kernel.hpp:47
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_bwd_kernel.hpp:52
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
Definition: fmha_bwd_kernel.hpp:663
ck_tile::remove_cvref_t< typename FmhaPipeline::VGradDataType > VGradDataType
Definition: fmha_bwd_kernel.hpp:60
ck_tile::remove_cvref_t< typename FmhaPipeline::GemmDataType > GemmDataType
Definition: fmha_bwd_kernel.hpp:51
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:699
static constexpr bool kIsAvailable
Definition: fmha_bwd_kernel.hpp:80
ck_tile::remove_cvref_t< QGradEpiloguePipeline_ > QGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:39
static constexpr bool kHasDropout
Definition: fmha_bwd_kernel.hpp:71
ck_tile::remove_cvref_t< typename FmhaPipeline::AccDataType > AccDataType
Definition: fmha_bwd_kernel.hpp:53
CK_TILE_DEVICE void run_(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:705
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_bwd_kernel.hpp:49
static constexpr bool kHasBiasGrad
Definition: fmha_bwd_kernel.hpp:67
static constexpr bool kIsDeterministic
Definition: fmha_bwd_kernel.hpp:73
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_bwd_kernel.hpp:50
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_bwd_kernel.hpp:56
static constexpr index_t kPadHeadDimQ
Definition: fmha_bwd_kernel.hpp:64
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_bwd_kernel.hpp:36
static constexpr CK_TILE_HOST Kargs MakeKargs(Ts... args, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:336
ck_tile::remove_cvref_t< typename FmhaPipeline::KGradDataType > KGradDataType
Definition: fmha_bwd_kernel.hpp:59
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_bwd_kernel.hpp:68
static constexpr bool kUseTrLoad
Definition: fmha_bwd_kernel.hpp:74
static constexpr index_t kPadHeadDimV
Definition: fmha_bwd_kernel.hpp:65
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:41
std::conditional_t< kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:322
static constexpr bool kIsStoreRandval
Definition: fmha_bwd_kernel.hpp:72
static constexpr index_t kMaxSeqLenQ
Definition: fmha_bwd_kernel.hpp:75
ck_tile::index_t batch_stride_o
Definition: fmha_bwd_kernel.hpp:1301
ck_tile::index_t batch_stride_do
Definition: fmha_bwd_kernel.hpp:1300
ck_tile::index_t batch_stride_d
Definition: fmha_bwd_kernel.hpp:1302
void * d_ptr
Definition: fmha_bwd_kernel.hpp:1283
const void * o_ptr
Definition: fmha_bwd_kernel.hpp:1281
ck_tile::index_t hdim_v
Definition: fmha_bwd_kernel.hpp:1288
ck_tile::index_t nhead_stride_do
Definition: fmha_bwd_kernel.hpp:1293
ck_tile::index_t stride_o
Definition: fmha_bwd_kernel.hpp:1291
ck_tile::index_t nhead_stride_o
Definition: fmha_bwd_kernel.hpp:1294
const void * do_ptr
Definition: fmha_bwd_kernel.hpp:1282
ck_tile::index_t stride_do
Definition: fmha_bwd_kernel.hpp:1290
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:1287
float p_undrop
Definition: fmha_bwd_kernel.hpp:1285
ck_tile::index_t nhead_stride_d
Definition: fmha_bwd_kernel.hpp:1295
const int32_t * cu_seqlen_q_ptr
Definition: fmha_bwd_kernel.hpp:1309
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:1307
const int32_t * seqlen_q_ptr
Definition: fmha_bwd_kernel.hpp:1308
Definition: fmha_bwd_kernel.hpp:1249
Definition: fmha_bwd_kernel.hpp:1233
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::ODataType > ODataType
Definition: fmha_bwd_kernel.hpp:1241
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:1401
ck_tile::remove_cvref_t< FmhaBwdOGradDotO_ > FmhaBwdOGradDotO
Definition: fmha_bwd_kernel.hpp:1234
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_bwd_kernel.hpp:1399
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_d)
Definition: fmha_bwd_kernel.hpp:1317
static constexpr ck_tile::index_t kVHeaddim
Definition: fmha_bwd_kernel.hpp:1238
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:1403
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:1244
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::OGradDataType > OGradDataType
Definition: fmha_bwd_kernel.hpp:1242
static constexpr ck_tile::index_t kM0
Definition: fmha_bwd_kernel.hpp:1237
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:1235
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:1236
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, const void *seqstart_q_ptr, const void *seqlen_q_ptr, const void *cu_seqlen_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d)
Definition: fmha_bwd_kernel.hpp:1352
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:1390
static constexpr bool kPadSeqLenQ
Definition: fmha_bwd_kernel.hpp:1245
std::conditional_t< kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:1313
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:1255
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::DDataType > DDataType
Definition: fmha_bwd_kernel.hpp:1240
static constexpr bool kPadHeadDimV
Definition: fmha_bwd_kernel.hpp:1246
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition: fmha_bwd_kernel.hpp:1385
Definition: integral_constant.hpp:13
Definition: block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:777
Definition: coordinate_transform.hpp:1392
Definition: sequence.hpp:49
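
Putting the pieces together: the MakeKargs, GridSize, and BlockSize entry points listed above are all that host code needs to stage a launch of the convert-dQ kernel. A hedged batch-mode sketch follows; ConvertKernel stands for a concrete FmhaBwdConvertQGradKernel instantiation (not shown here), the batch-mode MakeKargs overload is assumed to be selected, and the final device dispatch through ck_tile's launch utilities is elided.

#include <hip/hip_runtime.h>
#include <tuple>

// Sketch only: ConvertKernel is assumed to be a batch-mode
// ck_tile::FmhaBwdConvertQGradKernel instantiation, so the batch-mode
// MakeKargs overload (signature listed above) participates in overload
// resolution.
template <typename ConvertKernel>
auto prepare_convert_dq_launch(const void* dq_acc_ptr, void* dq_ptr,
                               ck_tile::index_t batch, ck_tile::index_t nhead,
                               ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k,
                               ck_tile::index_t hdim_q,
                               ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc,
                               ck_tile::index_t nhead_stride_dq,
                               ck_tile::index_t nhead_stride_dq_acc,
                               ck_tile::index_t batch_stride_dq,
                               ck_tile::index_t batch_stride_dq_acc,
                               ck_tile::index_t split_stride_dq_acc)
{
    auto kargs = ConvertKernel::MakeKargs(dq_acc_ptr, dq_ptr, seqlen_q, seqlen_k, hdim_q,
                                          stride_dq, stride_dq_acc,
                                          nhead_stride_dq, nhead_stride_dq_acc,
                                          batch_stride_dq, batch_stride_dq_acc,
                                          split_stride_dq_acc);
    const dim3 grid  = ConvertKernel::GridSize(batch, nhead, seqlen_q); // assumed dim3-convertible
    const dim3 block = ConvertKernel::BlockSize();
    // A device-side entry point would then invoke ConvertKernel{}(kargs) on this
    // grid, e.g. via ck_tile's kernel-launch helpers or hipLaunchKernelGGL.
    return std::make_tuple(kargs, grid, block);
}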