fmha_bwd_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20 // dV[seqlen_k, hdim_v] = P^T[seqlen_k, seqlen_q] @ dO[seqlen_q, hdim_v]
21 // dP[seqlen_q, seqlen_k] = dO[seqlen_q, hdim_v] @ V^T[hdim_v, seqlen_k]
22 // D[seqlen_q] = rowsum(dO[seqlen_q, hdim_v] * O[seqlen_q, hdim_v])
23 // dS''[seqlen_q, seqlen_k] = P[seqlen_q, seqlen_k] * (dP[seqlen_q, seqlen_k] - D[seqlen_q])
24 // dBias[seqlen_q, seqlen_k] = dS'[seqlen_q, seqlen_k] = dS''[seqlen_q, seqlen_k]
25 // dK[seqlen_k, hdim_q] = dS'^T[seqlen_k, seqlen_q] @ Q[seqlen_q, hdim_q] * Scale[1]
26 // dQ[seqlen_q, hdim_q] = dS'[seqlen_q, seqlen_k] @ K[seqlen_k, hdim_q] * Scale[1]
27 
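The dS'' line above is the per-row softmax backward. Writing p = softmax(s'') for one row of P and using O = P @ V, the Jacobian contraction collapses to the rowsum term, which is why the kernel only needs the vector D[seqlen_q] rather than a full Jacobian:

\[
\frac{\partial p_j}{\partial s''_i} = p_j(\delta_{ij} - p_i)
\quad\Longrightarrow\quad
dS''_i = \sum_j \frac{\partial p_j}{\partial s''_i}\, dP_j = p_i\Big(dP_i - \sum_j p_j\, dP_j\Big),
\]
\[
\sum_j p_j\, dP_j = \sum_j p_j \sum_v dO_v\, V_{jv} = \sum_v dO_v\, O_v = D .
\]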
28 namespace ck_tile {
29 
30 template <typename FmhaPipeline_,
31  typename KGradEpiloguePipeline_,
32  typename VGradEpiloguePipeline_,
33  typename QGradEpiloguePipeline_ = void>
34 struct FmhaBwdDQDKDVKernel
35 {
40  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
41  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
42  static constexpr bool kUseQrQtrDorPipeline =
44  static_assert(!kUseQrQtrDorPipeline || !std::is_same_v<QGradEpiloguePipeline_, void>,
45  "QrQtrDorPipeline needs QGradEpiloguePipeline");
46 
62 
63  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
64  static constexpr index_t kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
65  static constexpr index_t kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
66  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
67  static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
70  static constexpr bool kHasMask = FmhaMask::IsMasking;
71  static constexpr bool kHasDropout = FmhaDropout::IsDropout;
72  static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
73  static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
74  static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
75  static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ;
76  static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0));
77 #if defined(__gfx950__)
78  static constexpr bool kIsAvailable = true;
79 #else
80  static constexpr bool kIsAvailable = !kUseTrLoad;
81 #endif
82 
83  // clang-format off
84  template <typename T> struct t2s;
85  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
86  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
87  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
88  // clang-format on
89 
90  CK_TILE_HOST static std::string GetName()
91  {
92  // sync with generate.py
93  // clang-format off
94  using bfs = typename FmhaPipeline::BlockFmhaShape;
95  using gbr0 = typename bfs::Gemm0BlockWarps;
96  using gbr1 = typename bfs::Gemm1BlockWarps;
97  using gbr4 = typename bfs::Gemm4BlockWarps;
98  using gwt0 = typename bfs::Gemm0WarpTile;
99  using gwt1 = typename bfs::Gemm1WarpTile;
100  #define _SS_ std::string
101  #define _TS_ std::to_string
102  auto pn = [&] () {
103  std::string n;
104  if (kPadHeadDimQ) n += "d" + _TS_(kPadHeadDimQ);
105  if (kPadHeadDimV) n += "dv"+ _TS_(kPadHeadDimV);
106  return n.empty() ? n : std::string("p") + n; }();
107  return
108  _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
109  "_" + (kIsGroupMode ? "group" : "batch") + "_" +
110  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" +
111  _TS_(bfs::kK4) + "x" + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
112  "r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" +
113  "r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" +
114  "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
115  "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
116  "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
117  ("o" + _TS_(kBlockPerCu)) + "_" +
118  ("maxq" + _TS_(kMaxSeqLenQ)) +
119  (pn.empty() ? "_npad" : "_" + pn) +
121  (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? gwt0::at(ck_tile::number<0>{}) == 16? "_dropout_wg16":"_dropout_wg32" : "_ndropout" ) +
122  (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload");
123  #undef _SS_
124  #undef _TS_
125  // clang-format on
126  }
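 // An illustrative generated name (hypothetical tile/warp values; the bias token from
 // the elided line 120 is abbreviated to "..."):
 //   fmha_bwd_d128_fp16_batch_b32x128x32x32x32x32x32x128x128_r1x4x1_r4x1x1_r4x1x1_
 //   w32x32x16_w32x32x16_o2_maxq0_npad..._ndbias_nmask_ndropout_ndeterministic_ntrload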
127 
128  template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
129  // arg
130  struct FmhaBwdEmptyKargs
131  {
132  };
133 
134  // kargs use aggregate initialization, so no constructor is provided;
135  // use inheritance to minimize karg size.
136  // users need to use the MakeKargs() function to create kargs.
137  struct FmhaBwdCommonKargs
138  {
139  const void* q_ptr;
140  const void* k_ptr;
141  const void* v_ptr;
142  const void* lse_ptr;
143  const void* do_ptr;
144  const void* d_ptr;
145  void* dq_acc_ptr; // can be dq_ptr for qrqtrdor pipeline
146  void* dk_ptr;
147  void* dv_ptr;
148 
149  ck_tile::index_t seqlen_q;
150  ck_tile::index_t seqlen_k;
151  ck_tile::index_t hdim_q;
152  ck_tile::index_t hdim_v;
153 
154  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k;
155  // if it is larger than 1, it indicates an MQA/GQA case
156  ck_tile::index_t num_head_q;
157  ck_tile::index_t nhead_ratio_qk;
158  float raw_scale;
159  float scale;
160 
161  ck_tile::index_t stride_q;
162  ck_tile::index_t stride_k;
163  ck_tile::index_t stride_v;
164  ck_tile::index_t stride_do;
165  ck_tile::index_t stride_dq_acc;
166  ck_tile::index_t stride_dk;
167  ck_tile::index_t stride_dv;
168 
169  ck_tile::index_t nhead_stride_q;
170  ck_tile::index_t nhead_stride_k;
171  ck_tile::index_t nhead_stride_v;
172  ck_tile::index_t nhead_stride_do;
173  ck_tile::index_t nhead_stride_lsed;
174  ck_tile::index_t nhead_stride_dq_acc;
175  ck_tile::index_t nhead_stride_dk;
176  ck_tile::index_t nhead_stride_dv;
177  };
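 // note: MakeKargs() stores the user's softmax scale twice: raw_scale as given, and
 // scale pre-multiplied by log2(e) (see static_cast<float>(scale * log2e_v<>) below)
 // so the pipelines can evaluate softmax with exp2 instead of exp.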
178 
179  struct FmhaBwdCommonBiasKargs
180  {
181  const void* bias_ptr = nullptr;
182  ck_tile::index_t stride_bias = 0;
183  ck_tile::index_t nhead_stride_bias = 0;
184  };
185 
186  struct FmhaBwdBatchModeBiasKargs : FmhaBwdCommonBiasKargs
187  {
188  ck_tile::index_t batch_stride_bias = 0;
189  };
190 
191  struct FmhaBwdAlibiKargs
192  {
193  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
194  const void* alibi_slope_ptr;
195  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
196  };
197 
198  struct FmhaBwdCommonBiasGradKargs
199  {
200  void* dbias_ptr = nullptr;
201  ck_tile::index_t stride_dbias = 0;
202  ck_tile::index_t nhead_stride_dbias = 0;
203  };
204 
205  struct FmhaBwdBatchModeBiasGradKargs : FmhaBwdCommonBiasGradKargs
206  {
207  ck_tile::index_t batch_stride_dbias = 0;
208  };
209 
210  struct FmhaBwdMaskKargs
211  {
212  ck_tile::index_t window_size_left, window_size_right;
213  ck_tile::GenericAttentionMaskEnum mask_type;
214  };
215 
217  {
218  template <typename T>
220  {
221  T val;
222  const T* ptr;
223  };
224 
228  };
229 
231  {
232  void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
233  {
234  float p_undrop = 1.0 - p_drop;
237  rp_undrop = 1.0 / p_undrop;
238  scale_rp_undrop = rp_undrop * raw_scale;
239 
240  this->drop_seed.val = seed;
241  this->drop_offset.val = offset;
242  this->is_drop_seed_offset_from_host = true;
243  }
244 
245  void init_dropout(float p_drop,
246  const uint64_t* seed_ptr,
247  const uint64_t* offset_ptr,
248  float raw_scale)
249  {
250  float p_undrop = 1.0 - p_drop;
253  rp_undrop = 1.0 / p_undrop;
254  scale_rp_undrop = rp_undrop * raw_scale;
255 
256  this->drop_seed.ptr = seed_ptr;
257  this->drop_offset.ptr = offset_ptr;
258  this->is_drop_seed_offset_from_host = false;
259  }
260 
261  float rp_undrop = 1;
262  float scale_rp_undrop = 1;
263  uint8_t p_undrop_in_uint8_t;
264  void* rand_val_ptr = nullptr;
265 
266  ck_tile::index_t stride_randval = 0;
267  ck_tile::index_t nhead_stride_randval = 0;
268  };
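 // The two init_dropout() overloads above cover both seeding schemes: the value
 // overload snapshots a host-side Philox seed/offset pair into the kargs, while the
 // pointer overload keeps device pointers and defers the read to the kernel,
 // recording the choice in is_drop_seed_offset_from_host.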
269 
270  struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
271  {
272  ck_tile::index_t batch_stride_randval = 0;
273  };
274 
275  struct FmhaBwdDeterministicKargs
276  {
277  ck_tile::index_t split_stride_dq_acc = 0;
278  };
279 
280  struct FmhaBwdBatchModeKargs
281  : FmhaBwdCommonKargs,
282  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
283  FmhaBwdBatchModeBiasKargs,
284  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
285  FmhaBwdAlibiKargs,
286  FmhaBwdEmptyKargs<0>>>,
287  std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
288  std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
289  std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>,
290  std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
291  {
292  ck_tile::index_t batch_stride_q;
293  ck_tile::index_t batch_stride_k;
294  ck_tile::index_t batch_stride_v;
295  ck_tile::index_t batch_stride_do;
296  ck_tile::index_t batch_stride_lsed;
297  ck_tile::index_t batch_stride_dq_acc;
298  ck_tile::index_t batch_stride_dk;
299  ck_tile::index_t batch_stride_dv;
300  };
301 
302  struct FmhaBwdGroupModeKargs
303  : FmhaBwdCommonKargs,
304  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
305  FmhaBwdCommonBiasKargs,
306  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
307  FmhaBwdAlibiKargs,
308  FmhaBwdEmptyKargs<0>>>,
309  std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
310  std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
311  std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
312  std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
313  {
314  const int32_t* seqstart_q_ptr;
315  const int32_t* seqstart_k_ptr;
316  const int32_t* seqlen_k_ptr;
317  };
318 
319  using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
320 
321  // std::variant<> can't take a list initializer; this overload is kept for backward compatibility
322  template <typename... Ts>
323  CK_TILE_HOST static constexpr Kargs
324  MakeKargs(Ts... args, const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
325  {
326  return MakeKargsImpl(
327  args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
328  }
329 
330  // std::variant<> can't take a list initializer; this overload is kept for backward compatibility
331  template <typename... Ts>
332  CK_TILE_HOST static constexpr Kargs
333  MakeKargs(Ts... args, const std::tuple<const void*, const void*>& drop_seed_offset)
334  {
335  return MakeKargsImpl(
336  args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
337  }
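Both overloads forward to MakeKargsImpl, whose last parameter is a std::variant over the two pair types; the kernel branches on variant::index(), as in this minimal standalone sketch (illustration only, not part of the header):

#include <cstdint>
#include <utility>
#include <variant>

using SeedOffset = std::variant<std::pair<uint64_t, uint64_t>,        // values from host
                                std::pair<const void*, const void*>>; // device pointers

void dispatch(const SeedOffset& drop_seed_offset)
{
    if(drop_seed_offset.index() == 0) // seed & offset come from host
    {
        const auto& [seed, offset] = std::get<0>(drop_seed_offset);
        // ... copy the values into kargs, as init_dropout(p_drop, seed, offset, scale) does
    }
    else // seed & offset come from device
    {
        const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
        // ... keep the pointers; the kernel dereferences them at run time
    }
}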
338 
339  template <bool Cond = !kIsGroupMode>
340  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
341  MakeKargsImpl(const void* q_ptr,
342  const void* k_ptr,
343  const void* v_ptr,
344  const void* bias_ptr,
345  const void* lse_ptr,
346  const void* do_ptr,
347  const void* d_ptr,
348  void* rand_val_ptr,
349  void* dk_ptr,
350  void* dv_ptr,
351  void* dbias_ptr,
352  void* dq_acc_ptr, // can be dq_ptr for the qrqtrdor pipeline
353  ck_tile::index_t seqlen_q,
354  ck_tile::index_t seqlen_k,
355  ck_tile::index_t hdim_q,
356  ck_tile::index_t hdim_v,
357  ck_tile::index_t num_head_q,
358  ck_tile::index_t nhead_ratio_qk,
359  float scale,
360  ck_tile::index_t stride_q,
361  ck_tile::index_t stride_k,
362  ck_tile::index_t stride_v,
363  ck_tile::index_t stride_bias,
364  ck_tile::index_t stride_randval,
365  ck_tile::index_t stride_do,
366  ck_tile::index_t stride_dq_acc,
367  ck_tile::index_t stride_dk,
368  ck_tile::index_t stride_dv,
369  ck_tile::index_t stride_dbias,
370  ck_tile::index_t nhead_stride_q,
371  ck_tile::index_t nhead_stride_k,
372  ck_tile::index_t nhead_stride_v,
373  ck_tile::index_t nhead_stride_bias,
374  ck_tile::index_t nhead_stride_randval,
375  ck_tile::index_t nhead_stride_do,
376  ck_tile::index_t nhead_stride_lsed,
377  ck_tile::index_t nhead_stride_dq_acc,
378  ck_tile::index_t nhead_stride_dk,
379  ck_tile::index_t nhead_stride_dv,
380  ck_tile::index_t nhead_stride_dbias,
381  ck_tile::index_t batch_stride_q,
382  ck_tile::index_t batch_stride_k,
383  ck_tile::index_t batch_stride_v,
384  ck_tile::index_t batch_stride_bias,
385  ck_tile::index_t batch_stride_randval,
386  ck_tile::index_t batch_stride_do,
387  ck_tile::index_t batch_stride_lsed,
388  ck_tile::index_t batch_stride_dq_acc,
389  ck_tile::index_t batch_stride_dk,
390  ck_tile::index_t batch_stride_dv,
391  ck_tile::index_t batch_stride_dbias,
392  ck_tile::index_t split_stride_dq_acc,
393  ck_tile::index_t window_size_left,
394  ck_tile::index_t window_size_right,
395  ck_tile::index_t mask_type,
396  float p_drop,
397  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
398  drop_seed_offset)
399  {
400  Kargs kargs{{q_ptr,
401  k_ptr,
402  v_ptr,
403  lse_ptr,
404  do_ptr,
405  d_ptr,
406  dq_acc_ptr,
407  dk_ptr,
408  dv_ptr,
409  seqlen_q,
410  seqlen_k,
411  hdim_q,
412  hdim_v,
413  num_head_q,
414  nhead_ratio_qk,
415  scale,
416  static_cast<float>(scale * ck_tile::log2e_v<>),
417  stride_q,
418  stride_k,
419  stride_v,
420  stride_do,
421  stride_dq_acc,
422  stride_dk,
423  stride_dv,
424  nhead_stride_q,
425  nhead_stride_k,
426  nhead_stride_v,
427  nhead_stride_do,
428  nhead_stride_lsed,
429  nhead_stride_dq_acc,
430  nhead_stride_dk,
431  nhead_stride_dv}, // args for common karg
432  {}, // placeholder for bias
433  {}, // placeholder for dbias
434  {}, // placeholder for mask
435  {}, // placeholder for dropout
436  {}, // placeholder for deterministic
437  batch_stride_q,
438  batch_stride_k,
439  batch_stride_v,
440  batch_stride_do,
441  batch_stride_lsed,
442  batch_stride_dq_acc,
443  batch_stride_dk,
444  batch_stride_dv};
445 
446  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
447  {
448  kargs.bias_ptr = bias_ptr;
449  kargs.stride_bias = stride_bias;
450  kargs.nhead_stride_bias = nhead_stride_bias;
451  kargs.batch_stride_bias = batch_stride_bias;
452  }
453  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
454  {
455  kargs.alibi_slope_ptr = bias_ptr;
456  kargs.alibi_slope_stride = stride_bias;
457  }
458 
459  if constexpr(kHasBiasGrad)
460  {
461  kargs.dbias_ptr = dbias_ptr;
462  kargs.stride_dbias = stride_dbias;
463  kargs.nhead_stride_dbias = nhead_stride_dbias;
464  kargs.batch_stride_dbias = batch_stride_dbias;
465  }
466 
467  if constexpr(kHasMask)
468  {
469  kargs.window_size_left = window_size_left;
470  kargs.window_size_right = window_size_right;
471  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
472  }
473 
474  if constexpr(kHasDropout)
475  {
476  if(drop_seed_offset.index() == 0) // seed & offset come from host
477  {
478  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
479  kargs.init_dropout(p_drop, seed, offset, scale);
480  }
481  else // seed & offset come from device
482  {
483  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
484  kargs.init_dropout(p_drop,
485  reinterpret_cast<const uint64_t*>(seed_ptr),
486  reinterpret_cast<const uint64_t*>(offset_ptr),
487  scale);
488  }
489 
490  if constexpr(kIsStoreRandval)
491  {
492  kargs.rand_val_ptr = rand_val_ptr;
493  kargs.stride_randval = stride_randval;
494  kargs.nhead_stride_randval = nhead_stride_randval;
495  kargs.batch_stride_randval = batch_stride_randval;
496  }
497  }
498 
499  if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline)
500  {
501  kargs.split_stride_dq_acc = split_stride_dq_acc;
502  }
503 
504  return kargs;
505  }
506 
507  template <bool Cond = kIsGroupMode>
508  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
509  MakeKargsImpl(const void* q_ptr,
510  const void* k_ptr,
511  const void* v_ptr,
512  const void* bias_ptr,
513  const void* lse_ptr,
514  const void* do_ptr,
515  const void* d_ptr,
516  void* rand_val_ptr,
517  void* dk_ptr,
518  void* dv_ptr,
519  void* dbias_ptr,
520  void* dq_acc_ptr,
521  const void* seqstart_q_ptr,
522  const void* seqstart_k_ptr,
523  const void* seqlen_k_ptr,
524  ck_tile::index_t hdim_q,
525  ck_tile::index_t hdim_v,
526  ck_tile::index_t num_head_q,
527  ck_tile::index_t nhead_ratio_qk,
528  float scale,
529  ck_tile::index_t stride_q,
530  ck_tile::index_t stride_k,
531  ck_tile::index_t stride_v,
532  ck_tile::index_t stride_bias,
533  ck_tile::index_t stride_randval,
534  ck_tile::index_t stride_do,
535  ck_tile::index_t stride_dq_acc,
536  ck_tile::index_t stride_dk,
537  ck_tile::index_t stride_dv,
538  ck_tile::index_t stride_dbias,
539  ck_tile::index_t nhead_stride_q,
540  ck_tile::index_t nhead_stride_k,
541  ck_tile::index_t nhead_stride_v,
542  ck_tile::index_t nhead_stride_bias,
543  ck_tile::index_t nhead_stride_randval,
544  ck_tile::index_t nhead_stride_do,
545  ck_tile::index_t nhead_stride_lsed,
546  ck_tile::index_t nhead_stride_dq_acc,
547  ck_tile::index_t nhead_stride_dk,
548  ck_tile::index_t nhead_stride_dv,
549  ck_tile::index_t nhead_stride_dbias,
550  ck_tile::index_t split_stride_dq_acc,
551  ck_tile::index_t window_size_left,
552  ck_tile::index_t window_size_right,
553  ck_tile::index_t mask_type,
554  float p_drop,
555  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
556  drop_seed_offset)
557  {
558  Kargs kargs{{q_ptr,
559  k_ptr,
560  v_ptr,
561  lse_ptr,
562  do_ptr,
563  d_ptr,
564  dq_acc_ptr,
565  dk_ptr,
566  dv_ptr,
567  -1, // seqlen will be updated by another pointer
568  -1, //
569  hdim_q,
570  hdim_v,
571  num_head_q,
572  nhead_ratio_qk,
573  scale,
574  static_cast<float>(scale * ck_tile::log2e_v<>),
575  stride_q,
576  stride_k,
577  stride_v,
578  stride_do,
579  stride_dq_acc,
580  stride_dk,
581  stride_dv,
582  nhead_stride_q,
583  nhead_stride_k,
584  nhead_stride_v,
585  nhead_stride_do,
586  nhead_stride_lsed,
587  nhead_stride_dq_acc,
588  nhead_stride_dk,
589  nhead_stride_dv}, // args for common karg
590  {}, // placeholder for bias
591  {}, // placeholder for dbias
592  {}, // placeholder for mask
593  {}, // placeholder for dropout
594  {}, // placeholder for deterministic
595  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
596  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
597  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
598 
599  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
600  {
601  kargs.bias_ptr = bias_ptr;
602  kargs.stride_bias = stride_bias;
603  kargs.nhead_stride_bias = nhead_stride_bias;
604  }
605  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
606  {
607  kargs.alibi_slope_ptr = bias_ptr;
608  kargs.alibi_slope_stride = stride_bias;
609  }
610  if constexpr(kHasBiasGrad)
611  {
612  kargs.dbias_ptr = dbias_ptr;
613  kargs.stride_dbias = stride_dbias;
614  kargs.nhead_stride_dbias = nhead_stride_dbias;
615  }
616  if constexpr(kHasMask)
617  {
618  kargs.window_size_left = window_size_left;
619  kargs.window_size_right = window_size_right;
620  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
621  }
622  if constexpr(kHasDropout)
623  {
624  if(drop_seed_offset.index() == 0) // seed & offset come from host
625  {
626  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
627  kargs.init_dropout(p_drop, seed, offset, scale);
628  }
629  else // seed & offset come from device
630  {
631  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
632  kargs.init_dropout(p_drop,
633  reinterpret_cast<const uint64_t*>(seed_ptr),
634  reinterpret_cast<const uint64_t*>(offset_ptr),
635  scale);
636  }
637 
638  if constexpr(kIsStoreRandval)
639  {
640  kargs.rand_val_ptr = rand_val_ptr;
641  kargs.stride_randval = stride_randval;
642  kargs.nhead_stride_randval = nhead_stride_randval;
643  }
644  }
645  if constexpr(kIsDeterministic)
646  {
647  kargs.split_stride_dq_acc = split_stride_dq_acc;
648  }
649 
650  return kargs;
651  }
652 
653  CK_TILE_HOST static constexpr auto
654  GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
655  {
656  return dim3(
657  kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0),
658  nhead_,
659  batch_size_);
660  }
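 // Illustrative numbers (hypothetical tile size): with seqlen_k_ = 4096 and
 // FmhaPipeline::kN0 = 128, the default pipeline launches ceil(4096 / 128) = 32
 // blocks along x, one k-tile each; the QrQtrDor pipeline keeps a single block per
 // (nhead, batch) instead.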
661 
662  CK_TILE_DEVICE static constexpr auto GetTileIndex()
663  {
664  const index_t i_block = blockIdx.x;
665  const index_t i_nhead = blockIdx.y;
666  const index_t i_batch = blockIdx.z;
667 
668  return ck_tile::make_tuple(i_block, i_nhead, i_batch);
669  }
670 
671  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
672 
673  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
674  {
675  return ck_tile::max(FmhaPipeline::GetSmemSize(),
676  KGradEpiloguePipeline::GetSmemSize(),
677  VGradEpiloguePipeline::GetSmemSize());
678  }
679 
680  CK_TILE_DEVICE void operator()(Kargs kargs) const
681  {
682  if constexpr(kIsAvailable)
683  run_(std::move(kargs));
684  }
685 
686  CK_TILE_DEVICE void run_(Kargs kargs) const
687  {
688  // allocate LDS
689  __shared__ char smem_ptr[GetSmemSize()];
690 
691  // divide problem
692  const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
693 
694  const index_t i_n0 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN0);
695 
696  long_index_t batch_offset_q = 0;
697  long_index_t batch_offset_k = 0;
698  long_index_t batch_offset_v = 0;
699  long_index_t batch_offset_bias = 0;
700  long_index_t batch_offset_randval = 0;
701  long_index_t batch_offset_do = 0;
702  long_index_t batch_offset_lsed = 0;
703  long_index_t batch_offset_dq_acc = 0;
704  long_index_t batch_offset_dk = 0;
705  long_index_t batch_offset_dv = 0;
706  long_index_t batch_offset_dbias = 0;
707 
708  if constexpr(kIsGroupMode)
709  {
710  // get starting offset for each batch
711  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
712  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
713 
714  batch_offset_q = query_start * kargs.stride_q;
715  batch_offset_k = key_start * kargs.stride_k;
716  batch_offset_v = key_start * kargs.stride_v;
717  batch_offset_do = query_start * kargs.stride_do;
718  batch_offset_lsed = query_start;
719  batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
720  batch_offset_dk = key_start * kargs.stride_dk;
721  batch_offset_dv = key_start * kargs.stride_dv;
722  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
723  {
724  batch_offset_bias = query_start * kargs.stride_bias;
725  }
726  if constexpr(kHasBiasGrad)
727  {
728  batch_offset_dbias = query_start * kargs.stride_dbias;
729  }
730  else
731  {
732  batch_offset_dbias = key_start;
733  }
734  if constexpr(kIsStoreRandval)
735  {
736  batch_offset_randval = query_start * kargs.stride_randval;
737  }
738 
739  // get real # queries & # keys under group mode
740  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
741  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
742  if(kargs.seqlen_k_ptr != nullptr)
743  {
744  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
745  }
746  else
747  {
748  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
749  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
750  }
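 // Illustrative layout (hypothetical lengths): seqstart_q = {0, 128, 320} packs two
 // sequences back to back, so batch 1 occupies rows [128, 320) and gets
 // kargs.seqlen_q = 320 - 128 = 192; when seqlen_k_ptr is provided it overrides the
 // seqstart_k-derived key length above.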
751 
752  // the number of required blocks differs between groups; terminate unnecessary
753  // blocks early
754  if constexpr(!kUseQrQtrDorPipeline)
755  if(kargs.seqlen_k <= i_n0)
756  return;
757  }
758  else
759  {
760  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
761  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
762  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
763  batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
764  batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
765  batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
766  batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
767  batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
768  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
769  {
770  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
771  }
772  if constexpr(kHasBiasGrad)
773  {
774  batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
775  }
776  if constexpr(kIsStoreRandval)
777  {
778  batch_offset_randval =
779  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
780  }
781  }
782 
783  // for simplicity, we fold the batch stride into the base pointers
784  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
785  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
786  batch_offset_q;
787  const KDataType* k_ptr =
788  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
789  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
790  batch_offset_k;
791  const VDataType* v_ptr =
792  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
793  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
794  batch_offset_v;
795  const LSEDataType* lse_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_ptr) +
796  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
797  batch_offset_lsed;
798  const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
799  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
800  batch_offset_lsed;
801  const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
802  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
803  batch_offset_do;
804  auto dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
805  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk + batch_offset_dk;
806  auto dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
807  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv + batch_offset_dv;
808 
809  // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
810  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
811  q_ptr,
812  make_tuple(kargs.seqlen_q, kargs.hdim_q),
813  make_tuple(kargs.stride_q, 1),
815  number<1>{});
816  const auto q_dram = pad_tensor_view(
817  q_dram_naive,
819  sequence<false, (kPadHeadDimQ > 0)>{});
820 
821  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
822  k_ptr,
823  make_tuple(kargs.seqlen_k, kargs.hdim_q),
824  make_tuple(kargs.stride_k, 1),
826  number<1>{});
827  const auto k_dram = pad_tensor_view(
828  k_dram_naive,
830  sequence<false, (kPadHeadDimQ > 0)>{});
831 
832  const auto v_dram = [&]() {
833  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
834  v_ptr,
835  make_tuple(kargs.seqlen_k, kargs.hdim_v),
836  make_tuple(kargs.stride_v, 1),
838  number<1>{});
839  return pad_tensor_view(
840  v_dram_naive,
842  sequence<false, (kPadHeadDimV > 0)>{});
843  }();
844 
845  // lse and d are fine to read unpadded, as they are not on the reduction dimension
846  const auto lse_dram = make_naive_tensor_view_packed<address_space_enum::global>(
847  lse_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});
848 
849  const auto d_dram = make_naive_tensor_view_packed<address_space_enum::global>(
850  d_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});
851 
852  const auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
853  do_ptr,
854  make_tuple(kargs.seqlen_q, kargs.hdim_v),
855  make_tuple(kargs.stride_do, 1),
857  number<1>{});
858  const auto do_dram = pad_tensor_view(
859  do_dram_naive,
861  sequence<false, (kPadHeadDimV > 0)>{});
862 
863  auto q_dram_window = make_tile_window(
864  q_dram,
866  {0, 0});
867 
868  auto k_dram_window = make_tile_window(
869  k_dram,
871  {i_n0, 0});
872 
873  auto v_dram_window = make_tile_window(
874  v_dram,
876  {i_n0, 0});
877 
878  auto do_dram_window = make_tile_window(
879  do_dram,
881  {0, 0});
882 
883  auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
884  constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
885  using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
886 
887  auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
888  if constexpr(kUseKSplit)
889  return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
890  static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
891  batch_offset_dq_acc;
892  else
893  return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
894  batch_offset_dq_acc;
895  }();
896 
897  constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
898  memory_operation_enum::set, memory_operation_enum::atomic_add);
899  const auto dq_acc_dram_naive =
900  make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
901  dq_acc_ptr,
902  make_tuple(kargs.seqlen_q, kargs.hdim_q),
903  make_tuple(kargs.stride_dq_acc, 1),
905  number<1>{});
906  const auto dq_acc_dram = pad_tensor_view(
907  dq_acc_dram_naive,
909  sequence<false, (kPadHeadDimQ > 0)>{});
910  return make_tile_window(
911  dq_acc_dram,
913  {0, 0});
914  }();
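 // Note on the two dq layouts: with the deterministic k-split layout every k-tile
 // writes its own dq_acc slice (memory_operation_enum::set, slices
 // split_stride_dq_acc apart) and the convert kernel below reduces over the splits;
 // otherwise all k-tiles accumulate into a single buffer via atomic_add.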
915 
916  auto lse_dram_window =
917  make_tile_window(lse_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
918 
919  auto d_dram_window = make_tile_window(d_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
920 
923  constexpr auto bias_dram_window_lengths =
924  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
925  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
926  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
927  {
928  const BiasDataType* bias_ptr =
929  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
930  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
931  batch_offset_bias;
932 
933  const auto bias_dram = [&]() {
934  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
935  bias_ptr,
936  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
937  make_tuple(kargs.stride_bias, 1),
939  number<1>{});
940 
941  return pad_tensor_view(
942  bias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
943  }();
944 
945  return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0});
946  }
947  else
948  {
949  return make_null_tile_window(bias_dram_window_lengths);
950  }
951  }();
952 
953  auto dbias_dram_window = [&, i_nhead_ = i_nhead]() {
954  if constexpr(kHasBiasGrad)
955  {
956  BiasGradDataType* dbias_ptr =
957  reinterpret_cast<BiasGradDataType*>(kargs.dbias_ptr) +
958  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dbias +
959  batch_offset_dbias;
960 
961  auto dbias_dram = [&]() {
962  const auto dbias_dram_naive =
963  make_naive_tensor_view<address_space_enum::global>(
964  dbias_ptr,
965  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
966  make_tuple(kargs.stride_dbias, 1),
968  number<1>{});
969 
970  return pad_tensor_view(
971  dbias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
972  }();
973 
974  return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0});
975  }
976  else
977  {
978  return make_null_tile_window(bias_dram_window_lengths);
979  }
980  }();
981 
982  // workaround: structured bindings cannot be captured by lambdas before C++20, so copy i_batch/i_nhead
983  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
984  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
985  {
986  // data loading, shared by entire wg
987  // TODO: how to use s_read?
988  AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
989  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
990  slope *= ck_tile::log2e_v<>;
991  if constexpr(kHasMask)
992  {
993  return make_alibi_from_lr_mask<AccDataType, false>(slope,
994  kargs.window_size_left,
995  kargs.window_size_right,
996  kargs.seqlen_q,
997  kargs.seqlen_k,
998  kargs.mask_type);
999  }
1000  else
1001  {
1003  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1004  }
1005  }
1006  else
1007  {
1009  }
1010  }();
1011 
1012  // dropout
1013  float rp_undrop = 1;
1014  float scale_rp_undrop = 1;
1015  if constexpr(kHasDropout)
1016  {
1017  rp_undrop = kargs.rp_undrop;
1018  scale_rp_undrop = kargs.scale_rp_undrop;
1019  }
1020  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1021  if constexpr(kHasDropout)
1022  {
1023  return FmhaDropout{i_batch_,
1024  i_nhead_,
1025  kargs.num_head_q,
1026  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1027  : *kargs.drop_seed.ptr,
1028  kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
1029  : *kargs.drop_offset.ptr,
1030  kargs.rp_undrop,
1031  kargs.p_undrop_in_uint8_t};
1032  }
1033  else
1034  {
1035  return FmhaDropout{};
1036  };
1037  }();
1038 
1039  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1040  constexpr auto randval_dram_window_lengths =
1041  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
1042  if constexpr(kIsStoreRandval)
1043  {
1044  RandValOutputDataType* rand_val_ptr =
1045  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1046  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1047  batch_offset_randval;
1048 
1049  const auto randval_dram = [&]() {
1050  const auto randval_dram_naive =
1051  make_naive_tensor_view<address_space_enum::global>(
1052  rand_val_ptr,
1053  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1054  make_tuple(kargs.stride_randval, 1),
1055  number<1>{},
1056  number<1>{});
1057 
1058  return pad_tensor_view(
1059  randval_dram_naive, randval_dram_window_lengths, sequence<false, true>{});
1060  }();
1061 
1062  return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0});
1063  }
1064  else
1065  {
1066  return make_null_tile_window(randval_dram_window_lengths);
1067  }
1068  }();
1069 
1070  FmhaMask mask = [&]() {
1071  if constexpr(kHasMask)
1072  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1073  kargs.window_size_left,
1074  kargs.window_size_right,
1075  kargs.seqlen_q,
1076  kargs.seqlen_k,
1078  else
1079  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1080  }();
1081 
1082  auto dk_dram = [&]() {
1083  const auto dk_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1084  dk_ptr,
1085  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1086  make_tuple(kargs.stride_dk, 1),
1088  number<1>{});
1089 
1090  return pad_tensor_view(
1091  dk_dram_naive,
1093  sequence<false, (kPadHeadDimQ > 0)>{});
1094  }();
1095 
1096  auto dv_dram = [&]() {
1097  const auto dv_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1098  dv_ptr,
1099  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1100  make_tuple(kargs.stride_dv, 1),
1102  number<1>{});
1103 
1104  return pad_tensor_view(
1105  dv_dram_naive,
1107  sequence<false, (kPadHeadDimV > 0)>{});
1108  }();
1109 
1110  auto dk_dram_window = make_tile_window(
1111  dk_dram,
1113  {i_n0, 0});
1114 
1115  auto dv_dram_window = make_tile_window(
1116  dv_dram,
1118  {i_n0, 0});
1119  if constexpr(!kUseQrQtrDorPipeline)
1120  {
1121  auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(smem_ptr,
1122  q_dram_window,
1123  k_dram_window,
1124  v_dram_window,
1125  bias_dram_window,
1126  randval_dram_window,
1127  do_dram_window,
1128  lse_dram_window,
1129  d_dram_window,
1130  dq_dram_window,
1131  dbias_dram_window,
1132  mask,
1133  position_encoding,
1134  kargs.raw_scale,
1135  kargs.scale,
1136  rp_undrop,
1137  scale_rp_undrop,
1138  dropout);
1139 
1140  KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile, nullptr);
1141  VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile, nullptr);
1142  }
1143  else
1144  {
1145  FmhaPipeline{}(smem_ptr,
1146  q_dram_window,
1147  k_dram_window,
1148  v_dram_window,
1149  bias_dram_window,
1150  randval_dram_window,
1151  do_dram_window,
1152  lse_dram_window,
1153  d_dram_window,
1154  dq_dram_window,
1155  dk_dram_window,
1156  dv_dram_window,
1157  dbias_dram_window,
1161  mask,
1162  position_encoding,
1163  kargs.raw_scale,
1164  kargs.scale,
1165  rp_undrop,
1166  scale_rp_undrop,
1167  dropout);
1168  }
1169  }
1170 };
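Taken together, the host-facing surface of this kernel is MakeKargs / GridSize / BlockSize / GetSmemSize. A hedged wiring sketch, kept in comments because every identifier here (pipeline types, argument list, launch helper) is a placeholder rather than the exact API of this header:

// using Kernel = ck_tile::FmhaBwdDQDKDVKernel<Pipeline, KGradEpilogue, VGradEpilogue>;
//
// auto kargs        = Kernel::MakeKargs(/* q/k/v/lse/do/d pointers, lengths, strides,
//                                          ... */ std::make_tuple(seed, offset));
// const dim3 grids  = Kernel::GridSize(batch, nhead, seqlen_k); // x spans k-tiles
// const dim3 blocks = Kernel::BlockSize();
// // launch Kernel{} with grids/blocks (hipLaunchKernelGGL or the ck_tile launch
// // helper), passing kargs by value; LDS usage is Kernel::GetSmemSize().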
1171 
1172 template <typename FmhaBwdOGradDotO_>
1173 struct FmhaBwdOGradDotOKernel
1174 {
1175  using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>;
1176  static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize;
1177  static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
1178  static constexpr ck_tile::index_t kM0 = kBlockSize;
1179  static constexpr ck_tile::index_t kVHeaddim = FmhaBwdOGradDotO::kVHeaddim;
1180 
1181  using ODataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::ODataType>;
1182  using OGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::OGradDataType>;
1183  using DDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::DDataType>;
1184 
1185  static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode;
1186  static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ;
1187  static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV;
1188 
1189  // clang-format off
1190  template <typename T> struct t2s;
1191  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
1192  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
1193  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
1194  // clang-format on
1195 
1196  CK_TILE_HOST static std::string GetName()
1197  {
1198  // sync with generate.py
1199  // clang-format off
1200 
1201  #define _SS_ std::string
1202  #define _TS_ std::to_string
1203  auto pn = [&] () {
1204  std::string n;
1205  if (kPadSeqLenQ) n += "s";
1206  if (kPadHeadDimV) n += "dv";
1207  return n.empty() ? n : std::string("p") + n; }();
1208  return
1209  _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s<ODataType>::name) +
1210  "_" + (kIsGroupMode ? "group" : "batch") + "_" +
1211  ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn);
1212  #undef _SS_
1213  #undef _TS_
1214  // clang-format on
1215  }
1216 
1217  // kargs use aggregate initialization, so no constructor is provided;
1218  // use inheritance to minimize karg size.
1219  // users need to use the MakeKargs() function to create kargs.
1220  struct FmhaBwdOGradDotOCommonKargs
1221  {
1222  const void* o_ptr;
1223  const void* do_ptr;
1224  void* d_ptr;
1225 
1226  float p_undrop;
1227 
1228  ck_tile::index_t seqlen_q;
1229  ck_tile::index_t hdim_v;
1230 
1231  ck_tile::index_t stride_do;
1232  ck_tile::index_t stride_o;
1233 
1234  ck_tile::index_t nhead_stride_do;
1235  ck_tile::index_t nhead_stride_o;
1236  ck_tile::index_t nhead_stride_d;
1237  };
1238 
1239  struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
1240  {
1241  ck_tile::index_t batch_stride_do;
1242  ck_tile::index_t batch_stride_o;
1243  ck_tile::index_t batch_stride_d;
1244  };
1245 
1246  struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
1247  {
1248  const int32_t* seqstart_q_ptr;
1249  };
1250 
1251  using Kargs = std::
1252  conditional_t<kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs>;
1253 
1254  template <bool Cond = !kIsGroupMode>
1255  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1256  MakeKargs(const void* o_ptr,
1257  const void* do_ptr,
1258  void* d_ptr,
1259  float p_undrop,
1260  ck_tile::index_t seqlen_q,
1261  ck_tile::index_t hdim_v,
1262  ck_tile::index_t stride_do,
1263  ck_tile::index_t stride_o,
1264  ck_tile::index_t nhead_stride_do,
1265  ck_tile::index_t nhead_stride_o,
1266  ck_tile::index_t nhead_stride_d,
1267  ck_tile::index_t batch_stride_do,
1268  ck_tile::index_t batch_stride_o,
1269  ck_tile::index_t batch_stride_d)
1270  {
1271  Kargs kargs{{o_ptr,
1272  do_ptr,
1273  d_ptr,
1274  p_undrop,
1275  seqlen_q,
1276  hdim_v,
1277  stride_do,
1278  stride_o,
1279  nhead_stride_do,
1280  nhead_stride_o,
1281  nhead_stride_d},
1282  batch_stride_do,
1283  batch_stride_o,
1284  batch_stride_d};
1285 
1286  return kargs;
1287  }
1288 
1289  template <bool Cond = kIsGroupMode>
1290  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1291  MakeKargs(const void* o_ptr,
1292  const void* do_ptr,
1293  void* d_ptr,
1294  float p_undrop,
1295  const void* seqstart_q_ptr,
1296  ck_tile::index_t hdim_v,
1297  ck_tile::index_t stride_do,
1298  ck_tile::index_t stride_o,
1299  ck_tile::index_t nhead_stride_do,
1300  ck_tile::index_t nhead_stride_o,
1301  ck_tile::index_t nhead_stride_d)
1302  {
1303  Kargs kargs{{o_ptr,
1304  do_ptr,
1305  d_ptr,
1306  p_undrop,
1307  -1, // seqlen will be updated by another pointer
1308  hdim_v,
1309  stride_do,
1310  stride_o,
1311  nhead_stride_do,
1312  nhead_stride_o,
1313  nhead_stride_d},
1314  reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
1315 
1316  return kargs;
1317  }
1318 
1319  CK_TILE_HOST static constexpr auto
1320  GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
1321  {
1322  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
1323  }
1324 
1325  CK_TILE_DEVICE static constexpr auto GetTileIndex()
1326  {
1327  const index_t i_block = blockIdx.x;
1328  const index_t i_nhead = blockIdx.y;
1329  const index_t i_batch = blockIdx.z;
1330 
1331  return ck_tile::make_tuple(i_block, i_nhead, i_batch);
1332  }
1333 
1334  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
1335 
1336  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
1337 
1338  CK_TILE_DEVICE void operator()(Kargs kargs) const
1339  {
1340  // divide problem
1341  const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
1342 
1343  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0);
1344 
1345  long_index_t batch_offset_o = 0;
1346  long_index_t batch_offset_do = 0;
1347  long_index_t batch_offset_d = 0;
1348 
1349  if constexpr(kIsGroupMode)
1350  {
1351  // get starting offset for each batch
1352  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1353 
1354  batch_offset_o = query_start * kargs.stride_o;
1355  batch_offset_do = query_start * kargs.stride_do;
1356  batch_offset_d = query_start;
1357 
1358  // get real # queries & # keys under group mode
1359  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1360  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1361  // the number of required blocks differs between groups; terminate unnecessary
1362  // blocks early
1363  if(kargs.seqlen_q <= i_m0)
1364  {
1365  return;
1366  }
1367  }
1368  else
1369  {
1370  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1371  batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
1372  batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
1373  }
1374 
1375  // for simplicity, we fold the batch stride into the base pointers
1376  const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr) +
1377  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1378  batch_offset_o;
1379  const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
1380  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
1381  batch_offset_do;
1382  DDataType* d_ptr = reinterpret_cast<DDataType*>(kargs.d_ptr) +
1383  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_d +
1384  batch_offset_d;
1385 
1386  // O/dO/D DRAM and DRAM window
1387  const auto o_dram = [&]() {
1388  auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1389  o_ptr,
1390  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1391  make_tuple(kargs.stride_o, 1),
1393  number<1>{});
1394  return pad_tensor_view(o_dram_naive,
1397  }();
1398  const auto do_dram = [&]() {
1399  auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1400  do_ptr,
1401  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1402  make_tuple(kargs.stride_do, 1),
1404  number<1>{});
1405  return pad_tensor_view(do_dram_naive,
1408  }();
1409  auto d_dram = [&]() {
1410  const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
1411  d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
1412  return pad_tensor_view(
1413  d_dram_naive, make_tuple(number<kM0>{}), sequence<kPadSeqLenQ>{});
1414  }();
1415 
1416  auto o_dram_window =
1417  make_tile_window(o_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
1418 
1419  auto do_dram_window =
1420  make_tile_window(do_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
1421 
1422  auto d_dram_window = make_tile_window(d_dram, make_tuple(number<kM0>{}), {i_m0});
1423 
1424  FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop);
1425  }
1426 };
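For reference, the row-wise quantity this kernel produces is the D of the equation block at the top, with the dropout rescale folded in (a plain scalar sketch assuming the pipeline multiplies the rowsum by p_undrop, which is why MakeKargs takes it):

#include <cstddef>
#include <vector>

// D[m] = p_undrop * sum_v dO[m][v] * O[m][v]  (rowsum of the elementwise product)
std::vector<float> dot_do_o_reference(const std::vector<std::vector<float>>& o,
                                      const std::vector<std::vector<float>>& d_o,
                                      float p_undrop)
{
    std::vector<float> d(o.size(), 0.f);
    for(std::size_t m = 0; m < o.size(); ++m)
    {
        for(std::size_t v = 0; v < o[m].size(); ++v)
            d[m] += d_o[m][v] * o[m][v];
        d[m] *= p_undrop;
    }
    return d;
}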
1427 
1428 template <typename FmhaBwdConvertQGrad_>
1429 struct FmhaBwdConvertQGradKernel
1430 {
1431  using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>;
1432  static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize;
1433  static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
1434  static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0;
1435  static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0;
1436  static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim;
1437 
1437 
1438  using AccDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::AccDataType>;
1439  using QGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradDataType>;
1440 
1441  static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode;
1442  static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
1443  static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
1444  static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
1445 
1446  // clang-format off
1447  template <typename T> struct t2s;
1448  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
1449  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
1450  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
1451  // clang-format on
1452 
1453  CK_TILE_HOST static std::string GetName()
1454  {
1455  // sync with generate.py
1456  // clang-format off
1457 
1458  #define _SS_ std::string
1459  #define _TS_ std::to_string
1460  auto pn = [&] () {
1461  std::string n;
1462  if (kPadSeqLenQ) n += "s";
1463  if (kPadHeadDimQ) n += "d";
1464  return n.empty() ? n : std::string("p") + n; }();
1465  return
1466  _SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_"
1468  + "b" + _TS_(kM0) + "x" + _TS_(kN0) + "_"
1469  + (kIsGroupMode ? "group" : "batch") + "_"
1470  + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn)
1471  + (kIsDeterministic ? "_deterministic" : "_ndeterministic") ;
1472  #undef _SS_
1473  #undef _TS_
1474  // clang-format on
1475  }
1476 
1477  // to avoid the duplicated base class problem, introduce a template arg
1478  template <ck_tile::index_t I>
1479  struct FmhaBwdConvertQGradEmptyKargs
1480  {
1481  };
1482 
1483  // kargs use aggregate initialization, so no constructor is provided;
1484  // use inheritance to minimize karg size.
1485  // users need to use the MakeKargs() function to create kargs.
1486  struct FmhaBwdConvertQGradCommonKargs
1487  {
1488  const void* dq_acc_ptr;
1489  void* dq_ptr;
1490 
1491  ck_tile::index_t seqlen_q;
1492  ck_tile::index_t seqlen_k;
1493  ck_tile::index_t hdim_q;
1494 
1495  ck_tile::index_t stride_dq;
1496  ck_tile::index_t stride_dq_acc;
1497  ck_tile::index_t nhead_stride_dq;
1498  ck_tile::index_t nhead_stride_dq_acc;
1499  };
1500 
1501  struct FmhaBwdConvertQGradDeterministicKargs
1502  {
1503  ck_tile::index_t split_stride_dq_acc;
1504  };
1505 
1506  struct FmhaBwdConvertQGradBatchModeKargs
1507  : FmhaBwdConvertQGradCommonKargs,
1508  std::conditional_t<kIsDeterministic,
1509  FmhaBwdConvertQGradDeterministicKargs,
1510  FmhaBwdConvertQGradEmptyKargs<0>>
1511  {
1512  ck_tile::index_t batch_stride_dq;
1513  ck_tile::index_t batch_stride_dq_acc;
1514  };
1515 
1516  struct FmhaBwdConvertQGradGroupModeKargs
1517  : FmhaBwdConvertQGradCommonKargs,
1518  std::conditional_t<kIsDeterministic,
1519  FmhaBwdConvertQGradDeterministicKargs,
1520  FmhaBwdConvertQGradEmptyKargs<0>>
1521  {
1522  const int32_t* seqstart_q_ptr;
1523  const int32_t* seqstart_k_ptr;
1524  };
1525 
1526  using Kargs = std::conditional_t<kIsGroupMode,
1527  FmhaBwdConvertQGradGroupModeKargs,
1528  FmhaBwdConvertQGradBatchModeKargs>;
1529 
1530  template <bool Cond = !kIsGroupMode>
1531  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1532  MakeKargs(const void* dq_acc_ptr,
1533  void* dq_ptr,
1534  ck_tile::index_t seqlen_q,
1535  ck_tile::index_t seqlen_k,
1536  ck_tile::index_t hdim_q,
1537  ck_tile::index_t stride_dq,
1538  ck_tile::index_t stride_dq_acc,
1539  ck_tile::index_t nhead_stride_dq,
1540  ck_tile::index_t nhead_stride_dq_acc,
1541  ck_tile::index_t batch_stride_dq,
1542  ck_tile::index_t batch_stride_dq_acc,
1543  ck_tile::index_t split_stride_dq_acc)
1544  {
1545  Kargs kargs{{dq_acc_ptr,
1546  dq_ptr,
1547  seqlen_q,
1548  seqlen_k,
1549  hdim_q,
1550  stride_dq,
1551  stride_dq_acc,
1552  nhead_stride_dq,
1553  nhead_stride_dq_acc},
1554  {},
1555  batch_stride_dq,
1556  batch_stride_dq_acc};
1557 
1558  if constexpr(kIsDeterministic)
1559  {
1560  kargs.split_stride_dq_acc = split_stride_dq_acc;
1561  }
1562 
1563  return kargs;
1564  }
1565 
1566  template <bool Cond = kIsGroupMode>
1567  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1568  MakeKargs(const void* dq_acc_ptr,
1569  void* dq_ptr,
1570  const void* seqstart_q_ptr,
1571  const void* seqstart_k_ptr,
1572  ck_tile::index_t hdim_q,
1573  ck_tile::index_t stride_dq,
1574  ck_tile::index_t stride_dq_acc,
1575  ck_tile::index_t nhead_stride_dq,
1576  ck_tile::index_t nhead_stride_dq_acc,
1577  ck_tile::index_t split_stride_dq_acc)
1578  {
1579  Kargs kargs{{dq_acc_ptr,
1580  dq_ptr,
1581  -1, // seqlen will be updated by another pointer
1582  -1, //
1583  hdim_q,
1584  stride_dq,
1585  stride_dq_acc,
1586  nhead_stride_dq,
1587  nhead_stride_dq_acc},
1588  {},
1589  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
1590  reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
1591 
1592  if constexpr(kIsDeterministic)
1593  {
1594  kargs.split_stride_dq_acc = split_stride_dq_acc;
1595  }
1596 
1597  return kargs;
1598  }
1599 
1600  CK_TILE_HOST static constexpr auto
1601  GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
1602  {
1603  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
1604  }
1605 
1606  CK_TILE_DEVICE static constexpr auto GetTileIndex()
1607  {
1608  const index_t i_block = blockIdx.x;
1609  const index_t i_nhead = blockIdx.y;
1610  const index_t i_batch = blockIdx.z;
1611 
1612  return ck_tile::make_tuple(i_block, i_nhead, i_batch);
1613  }
1614 
1615  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
1616 
1617  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
1618 
1619  CK_TILE_DEVICE void operator()(Kargs kargs) const
1620  {
1621  // divide problem
1622  const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
1623 
1624  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0);
1625 
1626  long_index_t batch_offset_dq = 0;
1627  long_index_t batch_offset_dq_acc = 0;
1628  if constexpr(kIsGroupMode)
1629  {
1630  // get starting offset for each batch
1631  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1632  batch_offset_dq = query_start * kargs.stride_dq;
1633  batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
1634 
1635  // get real # queries & # keys under group mode
1636  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1637  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1638  if constexpr(kIsDeterministic)
1639  {
1640  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1641  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1642  }
1643  // the number of required blocks differs between groups; terminate unnecessary
1644  // blocks early
1645  if(kargs.seqlen_q <= i_m0)
1646  {
1647  return;
1648  }
1649  }
1650  else
1651  {
1652  batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
1653  batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
1654  }
1655 
1656  // for simplicity, we fold the batch stride into the base pointers
1657  QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
1658  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
1659  batch_offset_dq;
1660 
1661  // dQAcc/dQ DRAM and DRAM window
1662  const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
1663  if constexpr(kIsDeterministic)
1664  {
1665  const AccDataType* dq_acc_ptr =
1666  reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
1667  static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
1668  batch_offset_dq_acc;
1669 
1670  const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
1671 
1672  auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1673  dq_acc_ptr,
1674  make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
1675  make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
1677  number<1>{});
1678  return pad_tensor_view(dq_acc_dram_naive,
1681  }
1682  else
1683  {
1684  const AccDataType* dq_acc_ptr =
1685  reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
1686  static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
1687  batch_offset_dq_acc;
1688 
1689  auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1690  dq_acc_ptr,
1691  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1692  make_tuple(kargs.stride_dq_acc, 1),
1694  number<1>{});
1695  return pad_tensor_view(dq_acc_dram_naive,
1698  }
1699  }();
1700 
1701  auto dq_dram = [&]() {
1702  auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1703  dq_ptr,
1704  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1705  make_tuple(kargs.stride_dq, 1),
1707  number<1>{});
1708  return pad_tensor_view(dq_dram_naive,
1711  }();
1712 
1713  auto dq_acc_dram_window = [&]() {
1714  if constexpr(kIsDeterministic)
1715  {
1716  return make_tile_window(
1717  dq_acc_dram,
1719  {0, i_m0, 0});
1720  }
1721  else
1722  {
1723  return make_tile_window(
1724  dq_acc_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
1725  }
1726  }();
1727 
1728  auto dq_dram_window =
1729  make_tile_window(dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
1730 
1731  if constexpr(kIsDeterministic)
1732  {
1733  const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
1734  FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
1735  }
1736  else
1737  {
1738  FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
1739  }
1740  }
1741 };
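In reference terms the convert kernel finalizes dQ from the fp32 accumulator: under the deterministic layout it sums the per-k-tile splits, and otherwise it is a plain cast of a single accumulator. A minimal scalar sketch (hypothetical shapes; AccDataType assumed to be float):

#include <cstddef>
#include <vector>

// dq[m][h] = sum_s dq_acc[s][m][h]; the non-deterministic path is the same with a
// single split. A real kernel would also cast the result to QGradDataType.
std::vector<std::vector<float>>
convert_dq_reference(const std::vector<std::vector<std::vector<float>>>& dq_acc)
{
    const std::size_t seqlen_q = dq_acc[0].size();
    const std::size_t hdim_q   = dq_acc[0][0].size();
    std::vector<std::vector<float>> dq(seqlen_q, std::vector<float>(hdim_q, 0.f));
    for(const auto& split : dq_acc)
        for(std::size_t m = 0; m < seqlen_q; ++m)
            for(std::size_t h = 0; h < hdim_q; ++h)
                dq[m][h] += split[m][h];
    return dq;
}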
1742 
1743 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
#define _TS_
#define _SS_
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:2834
_Float16 fp16_t
Definition: half.hpp:110
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__device__ X atomic_add(X *p_dst, const X &x)
unsigned char uint8_t
Definition: stdint.h:124
unsigned __int64 uint64_t
Definition: stdint.h:136
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_position_encoding.hpp:137
ck_tile::index_t batch_stride_dq
Definition: fmha_bwd_kernel.hpp:1512
ck_tile::index_t batch_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1513
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:1491
ck_tile::index_t stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1496
ck_tile::index_t nhead_stride_dq
Definition: fmha_bwd_kernel.hpp:1497
ck_tile::index_t hdim_q
Definition: fmha_bwd_kernel.hpp:1493
ck_tile::index_t stride_dq
Definition: fmha_bwd_kernel.hpp:1495
const void * dq_acc_ptr
Definition: fmha_bwd_kernel.hpp:1488
ck_tile::index_t seqlen_k
Definition: fmha_bwd_kernel.hpp:1492
ck_tile::index_t nhead_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1498
ck_tile::index_t split_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1503
const int32_t * seqstart_k_ptr
Definition: fmha_bwd_kernel.hpp:1523
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:1522
Definition: fmha_bwd_kernel.hpp:1447
Definition: fmha_bwd_kernel.hpp:1430
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:1441
static constexpr bool kPadHeadDimQ
Definition: fmha_bwd_kernel.hpp:1443
static constexpr bool kIsDeterministic
Definition: fmha_bwd_kernel.hpp:1444
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:1432
static constexpr bool kPadSeqLenQ
Definition: fmha_bwd_kernel.hpp:1442
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition: fmha_bwd_kernel.hpp:1568
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:1433
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:1606
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t batch_stride_dq, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition: fmha_bwd_kernel.hpp:1532
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:1617
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:1619
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::AccDataType > AccDataType
Definition: fmha_bwd_kernel.hpp:1438
static constexpr ck_tile::index_t kM0
Definition: fmha_bwd_kernel.hpp:1434
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::QGradDataType > QGradDataType
Definition: fmha_bwd_kernel.hpp:1439
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition: fmha_bwd_kernel.hpp:1601
static constexpr ck_tile::index_t kN0
Definition: fmha_bwd_kernel.hpp:1435
ck_tile::remove_cvref_t< FmhaBwdConvertQGrad_ > FmhaBwdConvertQGrad
Definition: fmha_bwd_kernel.hpp:1431
std::conditional_t< kIsGroupMode, FmhaBwdConvertQGradGroupModeKargs, FmhaBwdConvertQGradBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:1528
static constexpr ck_tile::index_t kQKHeaddim
Definition: fmha_bwd_kernel.hpp:1436
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:1615
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:1453
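The entries above belong to the standalone dQ-convert kernel: it reads the float32 accumulation buffer dq_acc (filled either via atomic_add or, in deterministic mode, as per-split partials strided by split_stride_dq_acc) and writes the final dq in QGradDataType. Below is a minimal host-side reference of that reduction; the row-major layout and the num_splits count are illustrative assumptions, not taken from these entries.

// Host-side reference for the ConvertQGrad reduction over one (batch, head):
// sum the fp32 partials across splits, then narrow to the output type.
#include <cstdint>

template <typename QGradDataType>
void convert_dq_reference(const float* dq_acc,   // [num_splits, seqlen_q, hdim_q]
                          QGradDataType* dq,     // [seqlen_q, hdim_q]
                          std::int64_t seqlen_q, std::int64_t hdim_q,
                          std::int64_t stride_dq,           // row stride of dq
                          std::int64_t stride_dq_acc,       // row stride of dq_acc
                          std::int64_t split_stride_dq_acc, // distance between split buffers
                          std::int64_t num_splits)          // 1 unless deterministic splitting
{
    for(std::int64_t m = 0; m < seqlen_q; ++m)
        for(std::int64_t d = 0; d < hdim_q; ++d)
        {
            float acc = 0.f;
            for(std::int64_t s = 0; s < num_splits; ++s)
                acc += dq_acc[s * split_stride_dq_acc + m * stride_dq_acc + d];
            dq[m * stride_dq + d] = static_cast<QGradDataType>(acc);
        }
}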
Definition: fmha_bwd_kernel.hpp:192
const void * alibi_slope_ptr
Definition: fmha_bwd_kernel.hpp:194
ck_tile::index_t alibi_slope_stride
Definition: fmha_bwd_kernel.hpp:195
ck_tile::index_t batch_stride_dbias
Definition: fmha_bwd_kernel.hpp:207
ck_tile::index_t batch_stride_bias
Definition: fmha_bwd_kernel.hpp:188
ck_tile::index_t batch_stride_randval
Definition: fmha_bwd_kernel.hpp:272
Definition: fmha_bwd_kernel.hpp:291
ck_tile::index_t batch_stride_v
Definition: fmha_bwd_kernel.hpp:294
ck_tile::index_t batch_stride_k
Definition: fmha_bwd_kernel.hpp:293
ck_tile::index_t batch_stride_q
Definition: fmha_bwd_kernel.hpp:292
ck_tile::index_t batch_stride_do
Definition: fmha_bwd_kernel.hpp:295
ck_tile::index_t batch_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:297
ck_tile::index_t batch_stride_dk
Definition: fmha_bwd_kernel.hpp:298
ck_tile::index_t batch_stride_dv
Definition: fmha_bwd_kernel.hpp:299
ck_tile::index_t batch_stride_lsed
Definition: fmha_bwd_kernel.hpp:296
ck_tile::index_t nhead_stride_dbias
Definition: fmha_bwd_kernel.hpp:202
void * dbias_ptr
Definition: fmha_bwd_kernel.hpp:200
ck_tile::index_t stride_dbias
Definition: fmha_bwd_kernel.hpp:201
Definition: fmha_bwd_kernel.hpp:180
ck_tile::index_t stride_bias
Definition: fmha_bwd_kernel.hpp:182
ck_tile::index_t nhead_stride_bias
Definition: fmha_bwd_kernel.hpp:183
const void * bias_ptr
Definition: fmha_bwd_kernel.hpp:181
uint8_t p_undrop_in_uint8_t
Definition: fmha_bwd_kernel.hpp:263
float rp_undrop
Definition: fmha_bwd_kernel.hpp:261
ck_tile::index_t nhead_stride_randval
Definition: fmha_bwd_kernel.hpp:267
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr, float raw_scale)
Definition: fmha_bwd_kernel.hpp:245
float scale_rp_undrop
Definition: fmha_bwd_kernel.hpp:262
void * rand_val_ptr
Definition: fmha_bwd_kernel.hpp:264
ck_tile::index_t stride_randval
Definition: fmha_bwd_kernel.hpp:266
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
Definition: fmha_bwd_kernel.hpp:232
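init_dropout comes in two overloads: one takes the Philox seed/offset by value from the host, the other takes device pointers, and is_drop_seed_offset_from_host records which path was used. The derived scalars rp_undrop, scale_rp_undrop and p_undrop_in_uint8_t presumably follow the usual keep-probability bookkeeping; the sketch below shows that convention and is an assumption, not a copy of the code at fmha_bwd_kernel.hpp:232.

#include <cstdint>
#include <cmath>

// Presumed derivation of the dropout karg scalars from p_drop and the raw
// softmax scale (assumes p_drop < 1); mirrors the common keep-prob convention.
struct DropoutScalars
{
    float        rp_undrop;            // 1 / (1 - p_drop), rescales kept values
    float        scale_rp_undrop;      // raw_scale folded with rp_undrop
    std::uint8_t p_undrop_in_uint8_t;  // keep-prob quantized to u8 for the RNG compare
};

inline DropoutScalars make_dropout_scalars(float p_drop, float raw_scale)
{
    const float p_undrop = 1.f - p_drop;
    DropoutScalars s;
    s.rp_undrop           = 1.f / p_undrop;
    s.scale_rp_undrop     = raw_scale * s.rp_undrop;
    s.p_undrop_in_uint8_t = static_cast<std::uint8_t>(std::floor(p_undrop * 255.f));
    return s;
}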
Definition: fmha_bwd_kernel.hpp:138
ck_tile::index_t nhead_stride_dk
Definition: fmha_bwd_kernel.hpp:175
ck_tile::index_t stride_do
Definition: fmha_bwd_kernel.hpp:164
ck_tile::index_t seqlen_k
Definition: fmha_bwd_kernel.hpp:150
const void * q_ptr
Definition: fmha_bwd_kernel.hpp:139
ck_tile::index_t hdim_q
Definition: fmha_bwd_kernel.hpp:151
ck_tile::index_t nhead_stride_do
Definition: fmha_bwd_kernel.hpp:172
ck_tile::index_t num_head_q
Definition: fmha_bwd_kernel.hpp:156
const void * lse_ptr
Definition: fmha_bwd_kernel.hpp:142
float raw_scale
Definition: fmha_bwd_kernel.hpp:158
ck_tile::index_t nhead_stride_k
Definition: fmha_bwd_kernel.hpp:170
ck_tile::index_t nhead_stride_q
Definition: fmha_bwd_kernel.hpp:169
ck_tile::index_t stride_dv
Definition: fmha_bwd_kernel.hpp:167
ck_tile::index_t nhead_stride_lsed
Definition: fmha_bwd_kernel.hpp:173
void * dq_acc_ptr
Definition: fmha_bwd_kernel.hpp:145
ck_tile::index_t stride_q
Definition: fmha_bwd_kernel.hpp:161
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:149
ck_tile::index_t stride_dk
Definition: fmha_bwd_kernel.hpp:166
const void * do_ptr
Definition: fmha_bwd_kernel.hpp:143
float scale
Definition: fmha_bwd_kernel.hpp:159
void * dk_ptr
Definition: fmha_bwd_kernel.hpp:146
ck_tile::index_t nhead_stride_v
Definition: fmha_bwd_kernel.hpp:171
ck_tile::index_t stride_v
Definition: fmha_bwd_kernel.hpp:163
const void * d_ptr
Definition: fmha_bwd_kernel.hpp:144
const void * k_ptr
Definition: fmha_bwd_kernel.hpp:140
ck_tile::index_t nhead_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:174
ck_tile::index_t nhead_ratio_qk
Definition: fmha_bwd_kernel.hpp:157
void * dv_ptr
Definition: fmha_bwd_kernel.hpp:147
const void * v_ptr
Definition: fmha_bwd_kernel.hpp:141
ck_tile::index_t hdim_v
Definition: fmha_bwd_kernel.hpp:152
ck_tile::index_t stride_k
Definition: fmha_bwd_kernel.hpp:162
ck_tile::index_t nhead_stride_dv
Definition: fmha_bwd_kernel.hpp:176
ck_tile::index_t stride_dq_acc
Definition: fmha_bwd_kernel.hpp:165
ck_tile::index_t split_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:277
bool is_drop_seed_offset_from_host
Definition: fmha_bwd_kernel.hpp:227
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_bwd_kernel.hpp:225
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_bwd_kernel.hpp:226
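Each tensor in the batch-mode kargs is described by up to three strides: stride_* advances one sequence position (a row), nhead_stride_* one head, and batch_stride_* one batch, with the head dimension contiguous. For a packed [batch, nhead, seqlen, hdim] tensor, a layout assumed here purely for illustration, the triple works out as follows.

#include <cstdint>

// Stride triple for one tensor in a packed B x H x S x D layout.
struct Strides
{
    std::int64_t stride;        // one sequence position
    std::int64_t nhead_stride;  // one head
    std::int64_t batch_stride;  // one batch
};

inline Strides make_bhsd_strides(std::int64_t nhead, std::int64_t seqlen, std::int64_t hdim)
{
    return Strides{hdim, seqlen * hdim, nhead * seqlen * hdim};
}

// e.g. for Q[batch, num_head_q, seqlen_q, hdim_q]:
//   auto qs = make_bhsd_strides(num_head_q, seqlen_q, hdim_q);
//   stride_q = qs.stride, nhead_stride_q = qs.nhead_stride,
//   batch_stride_q = qs.batch_stride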
Definition: fmha_bwd_kernel.hpp:131
Definition: fmha_bwd_kernel.hpp:313
const int32_t * seqstart_k_ptr
Definition: fmha_bwd_kernel.hpp:315
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:314
const int32_t * seqlen_k_ptr
Definition: fmha_bwd_kernel.hpp:316
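Group (varlen) mode replaces seqlen_q/seqlen_k and the batch strides with prefix-sum start offsets; seqlen_k_ptr can additionally cap each group's effective key length. A hypothetical host helper that builds such an array from per-group lengths:

#include <cstdint>
#include <vector>

// Build the prefix-sum array behind seqstart_q_ptr / seqstart_k_ptr:
// seqstart[0] = 0, seqstart[g+1] = seqstart[g] + len[g].
inline std::vector<std::int32_t> make_seqstart(const std::vector<std::int32_t>& lens)
{
    std::vector<std::int32_t> seqstart(lens.size() + 1, 0);
    for(std::size_t g = 0; g < lens.size(); ++g)
        seqstart[g + 1] = seqstart[g] + lens[g];
    return seqstart;
}
// Usage: lens {3, 5, 2} -> seqstart {0, 3, 8, 10}; group g occupies
// rows [seqstart[g], seqstart[g+1]) of the packed tensor.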
Definition: fmha_bwd_kernel.hpp:211
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_bwd_kernel.hpp:213
ck_tile::index_t window_size_right
Definition: fmha_bwd_kernel.hpp:212
ck_tile::index_t window_size_left
Definition: fmha_bwd_kernel.hpp:212
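The mask kargs encode causal and sliding-window masking as a (window_size_left, window_size_right) pair plus a GenericAttentionMaskEnum tag. Under the common convention, assumed here rather than read out of block_masking.hpp, query row i attends key column j when j lies in [i - left, i + right], with a negative size meaning unbounded on that side:

#include <cstdint>

// Presumed window semantics; causal masking is then left < 0, right = 0.
inline bool is_attend(std::int64_t i, std::int64_t j,
                      std::int64_t window_size_left, std::int64_t window_size_right)
{
    const bool ok_left  = window_size_left  < 0 || j >= i - window_size_left;
    const bool ok_right = window_size_right < 0 || j <= i + window_size_right;
    return ok_left && ok_right;
}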
ck_tile::FmhaBwdDQDKDVKernel::t2s
Definition: fmha_bwd_kernel.hpp:84
ck_tile::FmhaBwdDQDKDVKernel
Definition: fmha_bwd_kernel.hpp:35
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(
    const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr,
    const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr,
    void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr,
    const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr,
    ck_tile::index_t hdim_q, ck_tile::index_t hdim_v,
    ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale,
    ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v,
    ck_tile::index_t stride_bias, ck_tile::index_t stride_randval,
    ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc,
    ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias,
    ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k,
    ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias,
    ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do,
    ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc,
    ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv,
    ck_tile::index_t nhead_stride_dbias, ck_tile::index_t split_stride_dq_acc,
    ck_tile::index_t window_size_left, ck_tile::index_t window_size_right,
    ck_tile::index_t mask_type, float p_drop,
    std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:509
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:671
ck_tile::remove_cvref_t< KGradEpiloguePipeline_ > KGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:37
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(
    const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr,
    const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr,
    void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr,
    ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k,
    ck_tile::index_t hdim_q, ck_tile::index_t hdim_v,
    ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale,
    ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v,
    ck_tile::index_t stride_bias, ck_tile::index_t stride_randval,
    ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc,
    ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias,
    ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k,
    ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias,
    ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do,
    ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc,
    ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv,
    ck_tile::index_t nhead_stride_dbias,
    ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k,
    ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias,
    ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do,
    ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_dq_acc,
    ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv,
    ck_tile::index_t batch_stride_dbias, ck_tile::index_t split_stride_dq_acc,
    ck_tile::index_t window_size_left, ck_tile::index_t window_size_right,
    ck_tile::index_t mask_type, float p_drop,
    std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:341
static constexpr auto BiasEnum
Definition: fmha_bwd_kernel.hpp:66
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:90
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaDropout > FmhaDropout
Definition: fmha_bwd_kernel.hpp:69
ck_tile::remove_cvref_t< typename FmhaPipeline::OGradDataType > OGradDataType
Definition: fmha_bwd_kernel.hpp:57
static constexpr bool kUseQrQtrDorPipeline
Definition: fmha_bwd_kernel.hpp:42
ck_tile::remove_cvref_t< VGradEpiloguePipeline_ > VGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasGradDataType > BiasGradDataType
Definition: fmha_bwd_kernel.hpp:61
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:673
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:40
static constexpr bool kHasMask
Definition: fmha_bwd_kernel.hpp:70
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_bwd_kernel.hpp:48
static constexpr CK_TILE_HOST Kargs MakeKargs(Ts... args, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:324
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:63
ck_tile::remove_cvref_t< typename FmhaPipeline::QGradDataType > QGradDataType
Definition: fmha_bwd_kernel.hpp:58
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:662
ck_tile::remove_cvref_t< typename FmhaPipeline::DDataType > DDataType
Definition: fmha_bwd_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_bwd_kernel.hpp:47
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_bwd_kernel.hpp:52
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
Definition: fmha_bwd_kernel.hpp:654
ck_tile::remove_cvref_t< typename FmhaPipeline::VGradDataType > VGradDataType
Definition: fmha_bwd_kernel.hpp:60
ck_tile::remove_cvref_t< typename FmhaPipeline::GemmDataType > GemmDataType
Definition: fmha_bwd_kernel.hpp:51
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:680
static constexpr bool kIsAvailable
Definition: fmha_bwd_kernel.hpp:80
ck_tile::remove_cvref_t< QGradEpiloguePipeline_ > QGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:39
static constexpr bool kHasDropout
Definition: fmha_bwd_kernel.hpp:71
ck_tile::remove_cvref_t< typename FmhaPipeline::AccDataType > AccDataType
Definition: fmha_bwd_kernel.hpp:53
CK_TILE_DEVICE void run_(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:686
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_bwd_kernel.hpp:49
static constexpr bool kHasBiasGrad
Definition: fmha_bwd_kernel.hpp:67
static constexpr bool kIsDeterministic
Definition: fmha_bwd_kernel.hpp:73
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_bwd_kernel.hpp:50
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_bwd_kernel.hpp:56
static constexpr index_t kPadHeadDimQ
Definition: fmha_bwd_kernel.hpp:64
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_bwd_kernel.hpp:36
static constexpr CK_TILE_HOST Kargs MakeKargs(Ts... args, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:333
ck_tile::remove_cvref_t< typename FmhaPipeline::KGradDataType > KGradDataType
Definition: fmha_bwd_kernel.hpp:59
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_bwd_kernel.hpp:68
static constexpr bool kUseTrLoad
Definition: fmha_bwd_kernel.hpp:74
static constexpr index_t kPadHeadDimV
Definition: fmha_bwd_kernel.hpp:65
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:41
std::conditional_t< kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:319
static constexpr bool kIsStoreRandval
Definition: fmha_bwd_kernel.hpp:72
static constexpr index_t kMaxSeqLenQ
Definition: fmha_bwd_kernel.hpp:75
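Taken together, the main-kernel entries above describe the host workflow: build Kargs with one of the MakeKargs overloads (the trailing tuple selects the host-value or device-pointer dropout seed variant), size the grid with GridSize(batch, nhead, seqlen_k), and launch kBlockSize-wide workgroups. A sketch follows, assuming a concrete FmhaBwdDQDKDVKernel instantiation and the launch_kernel/make_kernel/stream_config helpers from ck_tile/host.hpp; those helper signatures are an assumption here, not read from this header.

// Launch sketch only: Kernel is some ck_tile::FmhaBwdDQDKDVKernel<...>
// instantiation and kargs was produced by one of its MakeKargs overloads.
#include <ck_tile/core.hpp>
#include <ck_tile/host.hpp>

template <typename Kernel>
void run_fmha_bwd(typename Kernel::Kargs kargs,
                  ck_tile::index_t batch, ck_tile::index_t nhead,
                  ck_tile::index_t max_seqlen_k)
{
    // One workgroup per key/value tile, times batch and head.
    const auto grids  = Kernel::GridSize(batch, nhead, max_seqlen_k);
    const auto blocks = Kernel::BlockSize(); // kBlockSize threads per workgroup
    ck_tile::launch_kernel(
        ck_tile::stream_config{},
        ck_tile::make_kernel(Kernel{}, grids, blocks, 0 /*dynamic LDS*/, kargs));
}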
ck_tile::index_t batch_stride_o
Definition: fmha_bwd_kernel.hpp:1242
ck_tile::index_t batch_stride_do
Definition: fmha_bwd_kernel.hpp:1241
ck_tile::index_t batch_stride_d
Definition: fmha_bwd_kernel.hpp:1243
void * d_ptr
Definition: fmha_bwd_kernel.hpp:1224
const void * o_ptr
Definition: fmha_bwd_kernel.hpp:1222
ck_tile::index_t hdim_v
Definition: fmha_bwd_kernel.hpp:1229
ck_tile::index_t nhead_stride_do
Definition: fmha_bwd_kernel.hpp:1234
ck_tile::index_t stride_o
Definition: fmha_bwd_kernel.hpp:1232
ck_tile::index_t nhead_stride_o
Definition: fmha_bwd_kernel.hpp:1235
const void * do_ptr
Definition: fmha_bwd_kernel.hpp:1223
ck_tile::index_t stride_do
Definition: fmha_bwd_kernel.hpp:1231
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:1228
float p_undrop
Definition: fmha_bwd_kernel.hpp:1226
ck_tile::index_t nhead_stride_d
Definition: fmha_bwd_kernel.hpp:1236
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:1248
Definition: fmha_bwd_kernel.hpp:1190
ck_tile::FmhaBwdOGradDotOKernel
Definition: fmha_bwd_kernel.hpp:1174
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::ODataType > ODataType
Definition: fmha_bwd_kernel.hpp:1182
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:1336
ck_tile::remove_cvref_t< FmhaBwdOGradDotO_ > FmhaBwdOGradDotO
Definition: fmha_bwd_kernel.hpp:1175
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(
    const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop,
    ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v,
    ck_tile::index_t stride_do, ck_tile::index_t stride_o,
    ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o,
    ck_tile::index_t nhead_stride_d,
    ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_o,
    ck_tile::index_t batch_stride_d)
Definition: fmha_bwd_kernel.hpp:1256
static constexpr ck_tile::index_t kVHeaddim
Definition: fmha_bwd_kernel.hpp:1179
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:1338
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:1185
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::OGradDataType > OGradDataType
Definition: fmha_bwd_kernel.hpp:1183
static constexpr ck_tile::index_t kM0
Definition: fmha_bwd_kernel.hpp:1178
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(
    const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop,
    const void *seqstart_q_ptr, ck_tile::index_t hdim_v,
    ck_tile::index_t stride_do, ck_tile::index_t stride_o,
    ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o,
    ck_tile::index_t nhead_stride_d)
Definition: fmha_bwd_kernel.hpp:1291
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:1176
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:1177
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:1325
static constexpr bool kPadSeqLenQ
Definition: fmha_bwd_kernel.hpp:1186
std::conditional_t< kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:1252
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:1196
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::DDataType > DDataType
Definition: fmha_bwd_kernel.hpp:1181
static constexpr bool kPadHeadDimV
Definition: fmha_bwd_kernel.hpp:1187
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:1334
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition: fmha_bwd_kernel.hpp:1320
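FmhaBwdOGradDotO is the small preamble kernel that produces the per-row auxiliary vector D consumed by the main backward pipeline: for each query row, the dot product of dO and O over hdim_v. Its p_undrop karg suggests the keep-probability is folded in at this point; the sketch below makes that assumption explicit.

// Reference for FmhaBwdOGradDotO over one (batch, head):
//   D[m] = p_undrop * sum_d dO[m][d] * O[m][d]
// (folding p_undrop here is inferred from the kargs, an assumption).
#include <cstddef>

template <typename ODataType, typename OGradDataType, typename DDataType>
void dot_do_o_reference(const OGradDataType* do_, const ODataType* o,
                        DDataType* d, float p_undrop,
                        std::size_t seqlen_q, std::size_t hdim_v,
                        std::size_t stride_do, std::size_t stride_o)
{
    for(std::size_t m = 0; m < seqlen_q; ++m)
    {
        float acc = 0.f;
        for(std::size_t dv = 0; dv < hdim_v; ++dv)
            acc += static_cast<float>(do_[m * stride_do + dv]) *
                   static_cast<float>(o[m * stride_o + dv]);
        d[m] = static_cast<DDataType>(p_undrop * acc);
    }
}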
Definition: integral_constant.hpp:13
Definition: block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:777
Definition: coordinate_transform.hpp:1392
Definition: sequence.hpp:49