Composable Kernel: include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20 // dV[seqlen_k, hdim_v] = P^T[seqlen_k, seqlen_q] @ dO^T[hdim_v, seqlen_q]
21 // dP[seqlen_q, seqlen_k] = dO[seqlen_q, hdim_v] @ V[seqlen_k, hdim_v]
22 // D[seqlen_q] = rowsum(dO[seqlen_q, hdim_v] * O[seqlen_q, hdim_v])
23 // dS''[seqlen_q, seqlen_k] = P[seqlen_q, seqlen_k] * (dP[seqlen_q, seqlen_k] - D[seqlen_q])
24 // dBias[seqlen_q, seqlen_k] = dS'[seqlen_q, seqlen_k] = dS''[seqlen_q, seqlen_k]
25 // dK[seqlen_k, hdim_q] = dS'^T[seqlen_k, seqlen_q] @ Q^T[hdim_q, seqlen_q] * Scale[1]
26 // dQ[seqlen_q, hdim_q] = dS'[seqlen_q, seqlen_k] @ K^T[hdim_q, seqlen_k] * Scale[1]
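// For intuition only, a minimal host-side sketch (not part of this header) of the
// dS'' step above, assuming row-major float buffers of width seqlen_k:
//
//   // dS''[i][j] = P[i][j] * (dP[i][j] - D[i]), with D[i] = rowsum(dO[i] * O[i])
//   void softmax_bwd_row(const float* p, const float* dp, float d, float* ds, int seqlen_k)
//   {
//       for(int j = 0; j < seqlen_k; ++j)
//           ds[j] = p[j] * (dp[j] - d);
//   }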
27 
28 namespace ck_tile {
29 
30 template <typename FmhaPipeline_,
31  typename KGradEpiloguePipeline_,
32  typename VGradEpiloguePipeline_,
33  typename QGradEpiloguePipeline_ = void>
34 struct FmhaBwdDQDKDVKernel
35 {
36  using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
37  using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
38  using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
39  using QGradEpiloguePipeline = ck_tile::remove_cvref_t<QGradEpiloguePipeline_>;
40  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
41  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
42  static constexpr bool kUseQrQtrDorPipeline =
44  static_assert(!kUseQrQtrDorPipeline || !std::is_same_v<QGradEpiloguePipeline_, void>,
45  "QrQtrDorPipeline needs QGradEpiloguePipeline");
46 
62 
63  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
64  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
65  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
66  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
67  static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
68  using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
69  using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
70  static constexpr bool kHasMask = FmhaMask::IsMasking;
71  static constexpr bool kHasDropout = FmhaDropout::IsDropout;
72  static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
73  static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
74  static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
75  static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ;
76  static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0));
77 #if defined(__gfx950__)
78  static constexpr bool kIsAvailable = true;
79 #else
80  static constexpr bool kIsAvailable = !kUseTrLoad;
81 #endif
82 
83  // clang-format off
84  template <typename T> struct t2s;
85  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
86  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
87  // clang-format on
88 
89  CK_TILE_HOST static std::string GetName()
90  {
91  // sync with generate.py
92  // clang-format off
93  using bfs = typename FmhaPipeline::BlockFmhaShape;
94  using gbr0 = typename bfs::Gemm0BlockWarps;
95  using gbr1 = typename bfs::Gemm1BlockWarps;
96  using gbr4 = typename bfs::Gemm4BlockWarps;
97  using gwt0 = typename bfs::Gemm0WarpTile;
98  using gwt1 = typename bfs::Gemm1WarpTile;
99  #define _SS_ std::string
100  #define _TS_ std::to_string
101  auto pn = [&] () {
102  std::string n;
103  if (kPadHeadDimQ) n += "d";
104  if (kPadHeadDimV) n += "dv";
105  return n.empty() ? n : std::string("p") + n; }();
106  return
107  _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
108  "_" + (kIsGroupMode ? "group" : "batch") + "_" +
109  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" +
110  _TS_(bfs::kK4) + "x" + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
111  "r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" +
112  "r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" +
113  "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
114  "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
115  "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
116  ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
118  (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) +
119  (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload");
120  #undef _SS_
121  #undef _TS_
122  // clang-format on
123  }
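// Example (hypothetical instantiation): the generated string has the form
// "fmha_bwd_d<hdim>_<dtype>_<batch|group>_b<M0>x<N0>x<K0>x<K1>x<K2>x<K3>x<K4>x<QKHeaddim>x<VHeaddim>
//  _r<g0 warps>_r<g1 warps>_r<g4 warps>_w<g0 tile>_w<g1 tile>_o<blocks/CU>_<pad flags>..."
// and must stay in sync with generate.py, which uses it to match pre-generated kernels.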
124 
125  template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a
126  // template arg
128  {
129  };
130 
131  // kargs use aggregate initialization, so no constructor is provided
132  // inheritance is used to minimize the karg size
133  // users need to call the MakeKargs() function to create kargs
134  struct FmhaBwdCommonKargs
135  {
136  const void* q_ptr;
137  const void* k_ptr;
138  const void* v_ptr;
139  const void* lse_ptr;
140  const void* do_ptr;
141  const void* d_ptr;
142  void* dq_acc_ptr; // can be dq_ptr for qrqtrdor pipeline
143  void* dk_ptr;
144  void* dv_ptr;
145 
146  ck_tile::index_t seqlen_q;
147  ck_tile::index_t seqlen_k;
148  ck_tile::index_t hdim_q;
149  ck_tile::index_t hdim_v;
150 
151  // for MQA/GQA, nhead_q and nhead_k can differ; this parameter is nhead_q / nhead_k
152  // a value larger than 1 indicates the MQA/GQA case
153  ck_tile::index_t num_head_q;
154  ck_tile::index_t nhead_ratio_qk;
155  float raw_scale;
156  float scale;
157 
158  ck_tile::index_t stride_q;
159  ck_tile::index_t stride_k;
160  ck_tile::index_t stride_v;
161  ck_tile::index_t stride_do;
162  ck_tile::index_t stride_dq_acc;
163  ck_tile::index_t stride_dk;
164  ck_tile::index_t stride_dv;
165 
166  ck_tile::index_t nhead_stride_q;
167  ck_tile::index_t nhead_stride_k;
168  ck_tile::index_t nhead_stride_v;
169  ck_tile::index_t nhead_stride_do;
170  ck_tile::index_t nhead_stride_lsed;
171  ck_tile::index_t nhead_stride_dq_acc;
172  ck_tile::index_t nhead_stride_dk;
173  ck_tile::index_t nhead_stride_dv;
174  };
175 
176  struct FmhaBwdCommonBiasKargs
177  {
178  const void* bias_ptr = nullptr;
179  ck_tile::index_t stride_bias = 0;
180  ck_tile::index_t nhead_stride_bias = 0;
181  };
182 
183  struct FmhaBwdBatchModeBiasKargs : FmhaBwdCommonBiasKargs
184  {
185  ck_tile::index_t batch_stride_bias = 0;
186  };
187 
188  struct FmhaBwdAlibiKargs
189  {
190  // alibi slopes are batch*nhead*1; batch and group mode treat them the same way
191  const void* alibi_slope_ptr;
192  ck_tile::index_t alibi_slope_stride; // stride between batches, or 0 if all batches share the same slopes
193  };
194 
195  struct FmhaBwdCommonBiasGradKargs
196  {
197  void* dbias_ptr = nullptr;
198  ck_tile::index_t stride_dbias = 0;
199  ck_tile::index_t nhead_stride_dbias = 0;
200  };
201 
202  struct FmhaBwdBatchModeBiasGradKargs : FmhaBwdCommonBiasGradKargs
203  {
204  ck_tile::index_t batch_stride_dbias = 0;
205  };
206 
207  struct FmhaBwdMaskKargs
208  {
209  ck_tile::index_t window_size_left, window_size_right;
210  ck_tile::GenericAttentionMaskEnum mask_type;
211  };
212 
214  {
215  template <typename T>
216  union ValueOrPointer
217  {
218  T val;
219  const T* ptr;
220  };
221 
222  ValueOrPointer<uint64_t> drop_seed;
223  ValueOrPointer<uint64_t> drop_offset;
224  bool is_drop_seed_offset_from_host;
225  };
226 
228  {
229  void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
230  {
231  float p_undrop = 1.0 - p_drop;
232  p_undrop_in_uint8_t =
233  uint8_t(floor(p_undrop * std::numeric_limits<uint8_t>::max()));
234  rp_undrop = 1.0 / p_undrop;
235  scale_rp_undrop = rp_undrop * raw_scale;
236 
237  this->drop_seed.val = seed;
238  this->drop_offset.val = offset;
239  this->is_drop_seed_offset_from_host = true;
240  }
241 
242  void init_dropout(float p_drop,
243  const uint64_t* seed_ptr,
244  const uint64_t* offset_ptr,
245  float raw_scale)
246  {
247  float p_undrop = 1.0 - p_drop;
248  p_undrop_in_uint8_t =
249  uint8_t(floor(p_undrop * std::numeric_limits<uint8_t>::max()));
250  rp_undrop = 1.0 / p_undrop;
251  scale_rp_undrop = rp_undrop * raw_scale;
252 
253  this->drop_seed.ptr = seed_ptr;
254  this->drop_offset.ptr = offset_ptr;
255  this->is_drop_seed_offset_from_host = false;
256  }
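// Worked example of the scaling above: with p_drop = 0.2, p_undrop = 0.8, so
// rp_undrop = 1 / 0.8 = 1.25 and scale_rp_undrop = 1.25 * raw_scale; kept
// elements are scaled by rp_undrop so the expectation matches the
// non-dropout forward pass.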
257 
258  float rp_undrop = 1;
259  float scale_rp_undrop = 1;
260  uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
261  void* rand_val_ptr = nullptr;
262 
263  ck_tile::index_t stride_randval = 0;
264  ck_tile::index_t nhead_stride_randval = 0;
265  };
266 
267  struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
268  {
269  ck_tile::index_t batch_stride_randval = 0;
270  };
271 
272  struct FmhaBwdDeterministicKargs
273  {
274  ck_tile::index_t split_stride_dq_acc = 0;
275  };
276 
277  struct FmhaBwdBatchModeKargs
278  : FmhaBwdCommonKargs,
279  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
280  FmhaBwdBatchModeBiasKargs,
281  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
282  FmhaBwdAlibiKargs,
283  FmhaBwdEmptyKargs<0>>>,
284  std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
285  std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
286  std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>,
287  std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
288  {
289  ck_tile::index_t batch_stride_q;
290  ck_tile::index_t batch_stride_k;
291  ck_tile::index_t batch_stride_v;
292  ck_tile::index_t batch_stride_do;
293  ck_tile::index_t batch_stride_lsed;
294  ck_tile::index_t batch_stride_dq_acc;
295  ck_tile::index_t batch_stride_dk;
296  ck_tile::index_t batch_stride_dv;
297  };
298 
299  struct FmhaBwdGroupModeKargs
300  : FmhaBwdCommonKargs,
301  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
302  FmhaBwdCommonBiasKargs,
303  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
304  FmhaBwdAlibiKargs,
305  FmhaBwdEmptyKargs<0>>>,
306  std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
307  std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
308  std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
309  std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
310  {
311  const int32_t* seqstart_q_ptr;
312  const int32_t* seqstart_k_ptr;
313  const int32_t* seqlen_k_ptr;
314  };
315 
316  using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
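// The conditional bases above keep Kargs small: each disabled feature collapses
// to an empty FmhaBwdEmptyKargs<I> base (the distinct I avoids the duplicated
// base class problem), and the empty base optimization makes that free, e.g.:
//
//   struct Common { int x; };
//   struct E0 {}; struct E1 {};
//   struct K : Common, E0, E1 {};
//   static_assert(sizeof(K) == sizeof(Common)); // empty bases add no storage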
317 
318  // std::variant<> can't be initialized from a braced list; these overloads keep backward compatibility
319  template <typename... Ts>
320  CK_TILE_HOST static constexpr Kargs
321  MakeKargs(Ts... args, const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
322  {
323  return MakeKargsImpl(
324  args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
325  }
326 
327  // std::variant<> can't be initialized from a braced list; these overloads keep backward compatibility
328  template <typename... Ts>
329  CK_TILE_HOST static constexpr Kargs
330  MakeKargs(Ts... args, const std::tuple<const void*, const void*>& drop_seed_offset)
331  {
332  return MakeKargsImpl(
333  args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
334  }
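// Host-side usage sketch (argument list abbreviated; K stands for this kernel
// type and the variable names are hypothetical): the type of the trailing
// tuple selects how the dropout seed/offset reach the kernel.
//
//   auto kargs_host = K::MakeKargs(q, k, v, /* ... */,
//                                  std::make_tuple(seed, offset));         // uint64_t values from host
//   auto kargs_dev  = K::MakeKargs(q, k, v, /* ... */,
//                                  std::make_tuple(seed_ptr, offset_ptr)); // const void* device pointers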
335 
336  template <bool Cond = !kIsGroupMode>
337  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
338  MakeKargsImpl(const void* q_ptr,
339  const void* k_ptr,
340  const void* v_ptr,
341  const void* bias_ptr,
342  const void* lse_ptr,
343  const void* do_ptr,
344  const void* d_ptr,
345  void* rand_val_ptr,
346  void* dk_ptr,
347  void* dv_ptr,
348  void* dbias_ptr,
349  void* dq_acc_ptr, // can be dq_ptr for qrqtrdor pipeline
350  ck_tile::index_t seqlen_q,
351  ck_tile::index_t seqlen_k,
352  ck_tile::index_t hdim_q,
353  ck_tile::index_t hdim_v,
354  ck_tile::index_t num_head_q,
355  ck_tile::index_t nhead_ratio_qk,
356  float scale,
357  ck_tile::index_t stride_q,
358  ck_tile::index_t stride_k,
359  ck_tile::index_t stride_v,
360  ck_tile::index_t stride_bias,
361  ck_tile::index_t stride_randval,
362  ck_tile::index_t stride_do,
363  ck_tile::index_t stride_dq_acc,
364  ck_tile::index_t stride_dk,
365  ck_tile::index_t stride_dv,
366  ck_tile::index_t stride_dbias,
367  ck_tile::index_t nhead_stride_q,
368  ck_tile::index_t nhead_stride_k,
369  ck_tile::index_t nhead_stride_v,
370  ck_tile::index_t nhead_stride_bias,
371  ck_tile::index_t nhead_stride_randval,
372  ck_tile::index_t nhead_stride_do,
373  ck_tile::index_t nhead_stride_lsed,
374  ck_tile::index_t nhead_stride_dq_acc,
375  ck_tile::index_t nhead_stride_dk,
376  ck_tile::index_t nhead_stride_dv,
377  ck_tile::index_t nhead_stride_dbias,
378  ck_tile::index_t batch_stride_q,
379  ck_tile::index_t batch_stride_k,
380  ck_tile::index_t batch_stride_v,
381  ck_tile::index_t batch_stride_bias,
382  ck_tile::index_t batch_stride_randval,
383  ck_tile::index_t batch_stride_do,
384  ck_tile::index_t batch_stride_lsed,
385  ck_tile::index_t batch_stride_dq_acc,
386  ck_tile::index_t batch_stride_dk,
387  ck_tile::index_t batch_stride_dv,
388  ck_tile::index_t batch_stride_dbias,
389  ck_tile::index_t split_stride_dq_acc,
390  ck_tile::index_t window_size_left,
391  ck_tile::index_t window_size_right,
392  ck_tile::index_t mask_type,
393  float p_drop,
394  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
395  drop_seed_offset)
396  {
397  Kargs kargs{{q_ptr,
398  k_ptr,
399  v_ptr,
400  lse_ptr,
401  do_ptr,
402  d_ptr,
403  dq_acc_ptr,
404  dk_ptr,
405  dv_ptr,
406  seqlen_q,
407  seqlen_k,
408  hdim_q,
409  hdim_v,
410  num_head_q,
411  nhead_ratio_qk,
412  scale,
413  static_cast<float>(scale * ck_tile::log2e_v<>),
414  stride_q,
415  stride_k,
416  stride_v,
417  stride_do,
418  stride_dq_acc,
419  stride_dk,
420  stride_dv,
421  nhead_stride_q,
422  nhead_stride_k,
423  nhead_stride_v,
424  nhead_stride_do,
425  nhead_stride_lsed,
426  nhead_stride_dq_acc,
427  nhead_stride_dk,
428  nhead_stride_dv}, // args for common karg
429  {}, // placeholder for bias
430  {}, // placeholder for dbias
431  {}, // placeholder for mask
432  {}, // placeholder for dropout
433  {}, // placeholder for deterministic
434  batch_stride_q,
435  batch_stride_k,
436  batch_stride_v,
437  batch_stride_do,
438  batch_stride_lsed,
439  batch_stride_dq_acc,
440  batch_stride_dk,
441  batch_stride_dv};
442 
443  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
444  {
445  kargs.bias_ptr = bias_ptr;
446  kargs.stride_bias = stride_bias;
447  kargs.nhead_stride_bias = nhead_stride_bias;
448  kargs.batch_stride_bias = batch_stride_bias;
449  }
450  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
451  {
452  kargs.alibi_slope_ptr = bias_ptr;
453  kargs.alibi_slope_stride = stride_bias;
454  }
455 
456  if constexpr(kHasBiasGrad)
457  {
458  kargs.dbias_ptr = dbias_ptr;
459  kargs.stride_dbias = stride_dbias;
460  kargs.nhead_stride_dbias = nhead_stride_dbias;
461  kargs.batch_stride_dbias = batch_stride_dbias;
462  }
463 
464  if constexpr(kHasMask)
465  {
466  kargs.window_size_left = window_size_left;
467  kargs.window_size_right = window_size_right;
468  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
469  }
470 
471  if constexpr(kHasDropout)
472  {
473  if(drop_seed_offset.index() == 0) // seed & offset come from host
474  {
475  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
476  kargs.init_dropout(p_drop, seed, offset, scale);
477  }
478  else // seed & offset come from device
479  {
480  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
481  kargs.init_dropout(p_drop,
482  reinterpret_cast<const uint64_t*>(seed_ptr),
483  reinterpret_cast<const uint64_t*>(offset_ptr),
484  scale);
485  }
486 
487  if constexpr(kIsStoreRandval)
488  {
489  kargs.rand_val_ptr = rand_val_ptr;
490  kargs.stride_randval = stride_randval;
491  kargs.nhead_stride_randval = nhead_stride_randval;
492  kargs.batch_stride_randval = batch_stride_randval;
493  }
494  }
495 
496  if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline)
497  {
498  kargs.split_stride_dq_acc = split_stride_dq_acc;
499  }
500 
501  return kargs;
502  }
503 
504  template <bool Cond = kIsGroupMode>
505  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
506  MakeKargsImpl(const void* q_ptr,
507  const void* k_ptr,
508  const void* v_ptr,
509  const void* bias_ptr,
510  const void* lse_ptr,
511  const void* do_ptr,
512  const void* d_ptr,
513  void* rand_val_ptr,
514  void* dk_ptr,
515  void* dv_ptr,
516  void* dbias_ptr,
517  void* dq_acc_ptr,
518  const void* seqstart_q_ptr,
519  const void* seqstart_k_ptr,
520  const void* seqlen_k_ptr,
521  ck_tile::index_t hdim_q,
522  ck_tile::index_t hdim_v,
523  ck_tile::index_t num_head_q,
524  ck_tile::index_t nhead_ratio_qk,
525  float scale,
526  ck_tile::index_t stride_q,
527  ck_tile::index_t stride_k,
528  ck_tile::index_t stride_v,
529  ck_tile::index_t stride_bias,
530  ck_tile::index_t stride_randval,
531  ck_tile::index_t stride_do,
532  ck_tile::index_t stride_dq_acc,
533  ck_tile::index_t stride_dk,
534  ck_tile::index_t stride_dv,
535  ck_tile::index_t stride_dbias,
536  ck_tile::index_t nhead_stride_q,
537  ck_tile::index_t nhead_stride_k,
538  ck_tile::index_t nhead_stride_v,
539  ck_tile::index_t nhead_stride_bias,
540  ck_tile::index_t nhead_stride_randval,
541  ck_tile::index_t nhead_stride_do,
542  ck_tile::index_t nhead_stride_lsed,
543  ck_tile::index_t nhead_stride_dq_acc,
544  ck_tile::index_t nhead_stride_dk,
545  ck_tile::index_t nhead_stride_dv,
546  ck_tile::index_t nhead_stride_dbias,
547  ck_tile::index_t split_stride_dq_acc,
548  ck_tile::index_t window_size_left,
549  ck_tile::index_t window_size_right,
550  ck_tile::index_t mask_type,
551  float p_drop,
552  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
553  drop_seed_offset)
554  {
555  Kargs kargs{{q_ptr,
556  k_ptr,
557  v_ptr,
558  lse_ptr,
559  do_ptr,
560  d_ptr,
561  dq_acc_ptr,
562  dk_ptr,
563  dv_ptr,
564  -1, // seqlen will be updated by another pointer
565  -1, //
566  hdim_q,
567  hdim_v,
568  num_head_q,
569  nhead_ratio_qk,
570  scale,
571  static_cast<float>(scale * ck_tile::log2e_v<>),
572  stride_q,
573  stride_k,
574  stride_v,
575  stride_do,
576  stride_dq_acc,
577  stride_dk,
578  stride_dv,
579  nhead_stride_q,
580  nhead_stride_k,
581  nhead_stride_v,
582  nhead_stride_do,
583  nhead_stride_lsed,
584  nhead_stride_dq_acc,
585  nhead_stride_dk,
586  nhead_stride_dv}, // args for common karg
587  {}, // placeholder for bias
588  {}, // placeholder for dbias
589  {}, // placeholder for mask
590  {}, // placeholder for dropout
591  {}, // placeholder for deterministic
592  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
593  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
594  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
595 
596  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
597  {
598  kargs.bias_ptr = bias_ptr;
599  kargs.stride_bias = stride_bias;
600  kargs.nhead_stride_bias = nhead_stride_bias;
601  }
602  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
603  {
604  kargs.alibi_slope_ptr = bias_ptr;
605  kargs.alibi_slope_stride = stride_bias;
606  }
607  if constexpr(kHasBiasGrad)
608  {
609  kargs.dbias_ptr = dbias_ptr;
610  kargs.stride_dbias = stride_dbias;
611  kargs.nhead_stride_dbias = nhead_stride_dbias;
612  }
613  if constexpr(kHasMask)
614  {
615  kargs.window_size_left = window_size_left;
616  kargs.window_size_right = window_size_right;
617  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
618  }
619  if constexpr(kHasDropout)
620  {
621  if(drop_seed_offset.index() == 0) // seed & offset come from host
622  {
623  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
624  kargs.init_dropout(p_drop, seed, offset, scale);
625  }
626  else // seed & offset come from device
627  {
628  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
629  kargs.init_dropout(p_drop,
630  reinterpret_cast<const uint64_t*>(seed_ptr),
631  reinterpret_cast<const uint64_t*>(offset_ptr),
632  scale);
633  }
634 
635  if constexpr(kIsStoreRandval)
636  {
637  kargs.rand_val_ptr = rand_val_ptr;
638  kargs.stride_randval = stride_randval;
639  kargs.nhead_stride_randval = nhead_stride_randval;
640  }
641  }
642  if constexpr(kIsDeterministic)
643  {
644  kargs.split_stride_dq_acc = split_stride_dq_acc;
645  }
646 
647  return kargs;
648  }
649 
650  CK_TILE_HOST static constexpr auto
651  GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
652  {
653  return dim3(
654  kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0),
655  nhead_,
656  batch_size_);
657  }
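// Launch sketch (hedged; the kernel entry wrapper and its name live outside
// this header and are assumptions here):
//
//   dim3 grids  = K::GridSize(batch, nhead, max_seqlen_k); // x: k-tiles (or 1), y: head, z: batch
//   dim3 blocks = K::BlockSize();
//   my_kernel_entry<<<grids, blocks, 0, stream>>>(kargs);  // LDS is allocated statically per GetSmemSize()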
658 
659  CK_TILE_DEVICE static constexpr auto GetTileIndex()
660  {
661  const index_t i_block = blockIdx.x;
662  const index_t i_nhead = blockIdx.y;
663  const index_t i_batch = blockIdx.z;
664 
665  return ck_tile::make_tuple(i_block, i_nhead, i_batch);
666  }
667 
668  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
669 
670  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
671  {
672  return ck_tile::max(FmhaPipeline::GetSmemSize(),
673  KGradEpiloguePipeline::GetSmemSize(),
674  VGradEpiloguePipeline::GetSmemSize());
675  }
676 
677  CK_TILE_DEVICE void operator()(Kargs kargs) const
678  {
679  if constexpr(kIsAvailable)
680  run_(std::move(kargs));
681  }
682 
683  CK_TILE_DEVICE void run_(Kargs kargs) const
684  {
685  // allocate LDS
686  __shared__ char smem_ptr[GetSmemSize()];
687 
688  // divide problem
689  const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
690 
691  const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0);
692 
693  long_index_t batch_offset_q = 0;
694  long_index_t batch_offset_k = 0;
695  long_index_t batch_offset_v = 0;
696  long_index_t batch_offset_bias = 0;
697  long_index_t batch_offset_randval = 0;
698  long_index_t batch_offset_do = 0;
699  long_index_t batch_offset_lsed = 0;
700  long_index_t batch_offset_dq_acc = 0;
701  long_index_t batch_offset_dk = 0;
702  long_index_t batch_offset_dv = 0;
703  long_index_t batch_offset_dbias = 0;
704 
705  if constexpr(kIsGroupMode)
706  {
707  // get starting offset for each batch
708  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
709  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
710 
711  batch_offset_q = query_start * kargs.stride_q;
712  batch_offset_k = key_start * kargs.stride_k;
713  batch_offset_v = key_start * kargs.stride_v;
714  batch_offset_do = query_start * kargs.stride_do;
715  batch_offset_lsed = query_start;
716  batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
717  batch_offset_dk = key_start * kargs.stride_dk;
718  batch_offset_dv = key_start * kargs.stride_dv;
719  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
720  {
721  batch_offset_bias = query_start * kargs.stride_bias;
722  }
723  if constexpr(kHasBiasGrad)
724  {
725  batch_offset_dbias = query_start * kargs.stride_dbias;
726  }
727  else
728  {
729  batch_offset_dbias = key_start;
730  }
731  if constexpr(kIsStoreRandval)
732  {
733  batch_offset_randval = query_start * kargs.stride_randval;
734  }
735 
736  // get real # queries & # keys under group mode
737  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
738  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
739  if(kargs.seqlen_k_ptr != nullptr)
740  {
741  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
742  }
743  else
744  {
745  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
746  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
747  }
748 
749  // the number of required blocks differs between groups, so terminate unnecessary
750  // blocks early
751  if constexpr(!kUseQrQtrDorPipeline)
752  if(kargs.seqlen_k <= i_n0)
753  return;
754  }
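// Worked example of the group-mode bookkeeping above: with
// seqstart_q_ptr = {0, 5, 12} and i_batch = 1, query_start = 5 and
// kargs.seqlen_q = 12 - 5 = 7, so every Q-side view is advanced by
// 5 rows worth of stride and each group behaves like its own batch.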
755  else
756  {
757  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
758  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
759  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
760  batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
761  batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
762  batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
763  batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
764  batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
765  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
766  {
767  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
768  }
769  if constexpr(kHasBiasGrad)
770  {
771  batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
772  }
773  if constexpr(kIsStoreRandval)
774  {
775  batch_offset_randval =
776  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
777  }
778  }
779 
780  // for simplicity, we apply the batch stride by adjusting the pointer
781  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
782  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
783  batch_offset_q;
784  const KDataType* k_ptr =
785  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
786  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
787  batch_offset_k;
788  const VDataType* v_ptr =
789  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
790  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
791  batch_offset_v;
792  const LSEDataType* lse_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_ptr) +
793  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
794  batch_offset_lsed;
795  const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
796  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
797  batch_offset_lsed;
798  const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
799  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
800  batch_offset_do;
801  auto dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
802  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk + batch_offset_dk;
803  auto dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
804  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv + batch_offset_dv;
805 
806  // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
807  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
808  q_ptr,
809  make_tuple(kargs.seqlen_q, kargs.hdim_q),
810  make_tuple(kargs.stride_q, 1),
812  number<1>{});
813  const auto q_dram = pad_tensor_view(
814  q_dram_naive,
817 
818  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
819  k_ptr,
820  make_tuple(kargs.seqlen_k, kargs.hdim_q),
821  make_tuple(kargs.stride_k, 1),
823  number<1>{});
824  const auto k_dram = pad_tensor_view(
825  k_dram_naive,
828 
829  const auto v_dram = [&]() {
830  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
831  v_ptr,
832  make_tuple(kargs.seqlen_k, kargs.hdim_v),
833  make_tuple(kargs.stride_v, 1),
835  number<1>{});
836  return pad_tensor_view(
837  v_dram_naive,
840  }();
841 
842  // lse and d can safely be read unpadded, as they are not on the reduction dimension
843  const auto lse_dram = make_naive_tensor_view_packed<address_space_enum::global>(
844  lse_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});
845 
846  const auto d_dram = make_naive_tensor_view_packed<address_space_enum::global>(
847  d_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});
848 
849  const auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
850  do_ptr,
851  make_tuple(kargs.seqlen_q, kargs.hdim_v),
852  make_tuple(kargs.stride_do, 1),
854  number<1>{});
855  const auto do_dram = pad_tensor_view(
856  do_dram_naive,
859 
860  auto q_dram_window = make_tile_window(
861  q_dram,
863  {0, 0});
864 
865  auto k_dram_window = make_tile_window(
866  k_dram,
868  {i_n0, 0});
869 
870  auto v_dram_window = make_tile_window(
871  v_dram,
873  {i_n0, 0});
874 
875  auto do_dram_window = make_tile_window(
876  do_dram,
878  {0, 0});
879 
880  auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
881  constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
882  using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
883 
884  auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
885  if constexpr(kUseKSplit)
886  return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
887  static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
888  batch_offset_dq_acc;
889  else
890  return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
891  batch_offset_dq_acc;
892  }();
893 
894  constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
895  memory_operation_enum::set, memory_operation_enum::atomic_add);
896  const auto dq_acc_dram_naive =
897  make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
898  dq_acc_ptr,
899  make_tuple(kargs.seqlen_q, kargs.hdim_q),
900  make_tuple(kargs.stride_dq_acc, 1),
902  number<1>{});
903  const auto dq_acc_dram = pad_tensor_view(
904  dq_acc_dram_naive,
907  return make_tile_window(
908  dq_acc_dram,
910  {0, 0});
911  }();
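// Accumulation strategy above, summarized: with kIsDeterministic (and no
// QrQtrDor pipeline) every k-tile writes its own dQ slice, split_stride_dq_acc
// apart, using memory_operation_enum::set; FmhaBwdConvertQGradKernel later
// reduces them as
//
//   dq[m][d] = sum over i_tile_n of dq_acc[i_tile_n][m][d]
//
// Otherwise all k-tile workgroups atomic_add into one AccDataType buffer, which
// is faster but leaves the float summation order (and rounding) nondeterministic.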
912 
913  auto lse_dram_window =
914  make_tile_window(lse_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
915 
916  auto d_dram_window = make_tile_window(d_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
917 
920  constexpr auto bias_dram_window_lengths =
921  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
922  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
923  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
924  {
925  const BiasDataType* bias_ptr =
926  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
927  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
928  batch_offset_bias;
929 
930  const auto bias_dram = [&]() {
931  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
932  bias_ptr,
933  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
934  make_tuple(kargs.stride_bias, 1),
936  number<1>{});
937 
938  return pad_tensor_view(
939  bias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
940  }();
941 
942  return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0});
943  }
944  else
945  {
946  return make_null_tile_window(bias_dram_window_lengths);
947  }
948  }();
949 
950  auto dbias_dram_window = [&, i_nhead_ = i_nhead]() {
951  if constexpr(kHasBiasGrad)
952  {
953  BiasGradDataType* dbias_ptr =
954  reinterpret_cast<BiasGradDataType*>(kargs.dbias_ptr) +
955  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dbias +
956  batch_offset_dbias;
957 
958  auto dbias_dram = [&]() {
959  const auto dbias_dram_naive =
960  make_naive_tensor_view<address_space_enum::global>(
961  dbias_ptr,
962  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
963  make_tuple(kargs.stride_dbias, 1),
965  number<1>{});
966 
967  return pad_tensor_view(
968  dbias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
969  }();
970 
971  return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0});
972  }
973  else
974  {
975  return make_null_tile_window(bias_dram_window_lengths);
976  }
977  }();
978 
979  // workaround: lambdas cannot capture the structured bindings i_batch/i_nhead before C++20
980  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
981  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
982  {
983  // data loading, shared by entire wg
984  // TODO: how to use s_read?
985  AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
986  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
987  slope *= ck_tile::log2e_v<>;
988  if constexpr(kHasMask)
989  {
990  return make_alibi_from_lr_mask<AccDataType, false>(slope,
991  kargs.window_size_left,
992  kargs.window_size_right,
993  kargs.seqlen_q,
994  kargs.seqlen_k,
995  kargs.mask_type);
996  }
997  else
998  {
999  return Alibi<AccDataType, false>{
1000  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1001  }
1002  }
1003  else
1004  {
1005  return EmptyPositionEncoding<AccDataType>{};
1006  }
1007  }();
1008 
1009  // dropout
1010  float rp_undrop = 1;
1011  float scale_rp_undrop = 1;
1012  if constexpr(kHasDropout)
1013  {
1014  rp_undrop = kargs.rp_undrop;
1015  scale_rp_undrop = kargs.scale_rp_undrop;
1016  }
1017  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1018  if constexpr(kHasDropout)
1019  {
1020  return FmhaDropout{i_batch_,
1021  i_nhead_,
1022  kargs.num_head_q,
1023  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1024  : *kargs.drop_seed.ptr,
1025  kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
1026  : *kargs.drop_offset.ptr,
1027  kargs.rp_undrop,
1028  kargs.p_undrop_in_uint8_t};
1029  }
1030  else
1031  {
1032  return FmhaDropout{};
1033  };
1034  }();
1035 
1036  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1037  constexpr auto randval_dram_window_lengths =
1038  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
1039  if constexpr(kIsStoreRandval)
1040  {
1041  RandValOutputDataType* rand_val_ptr =
1042  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1043  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1044  batch_offset_randval;
1045 
1046  const auto randval_dram = [&]() {
1047  const auto randval_dram_naive =
1048  make_naive_tensor_view<address_space_enum::global>(
1049  rand_val_ptr,
1050  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1051  make_tuple(kargs.stride_randval, 1),
1052  number<1>{},
1053  number<1>{});
1054 
1055  return pad_tensor_view(
1056  randval_dram_naive, randval_dram_window_lengths, sequence<false, true>{});
1057  }();
1058 
1059  return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0});
1060  }
1061  else
1062  {
1063  return make_null_tile_window(randval_dram_window_lengths);
1064  }
1065  }();
1066 
1067  FmhaMask mask = [&]() {
1068  if constexpr(kHasMask)
1069  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1070  kargs.window_size_left,
1071  kargs.window_size_right,
1072  kargs.seqlen_q,
1073  kargs.seqlen_k,
1074  kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
1075  else
1076  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1077  }();
1078 
1079  auto dk_dram = [&]() {
1080  const auto dk_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1081  dk_ptr,
1082  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1083  make_tuple(kargs.stride_dk, 1),
1085  number<1>{});
1086 
1087  return pad_tensor_view(
1088  dk_dram_naive,
1091  }();
1092 
1093  auto dv_dram = [&]() {
1094  const auto dv_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1095  dv_ptr,
1096  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1097  make_tuple(kargs.stride_dv, 1),
1099  number<1>{});
1100 
1101  return pad_tensor_view(
1102  dv_dram_naive,
1105  }();
1106 
1107  auto dk_dram_window = make_tile_window(
1108  dk_dram,
1110  {i_n0, 0});
1111 
1112  auto dv_dram_window = make_tile_window(
1113  dv_dram,
1115  {i_n0, 0});
1116  if constexpr(!kUseQrQtrDorPipeline)
1117  {
1118  auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(smem_ptr,
1119  q_dram_window,
1120  k_dram_window,
1121  v_dram_window,
1122  bias_dram_window,
1123  randval_dram_window,
1124  do_dram_window,
1125  lse_dram_window,
1126  d_dram_window,
1127  dq_dram_window,
1128  dbias_dram_window,
1129  mask,
1130  position_encoding,
1131  kargs.raw_scale,
1132  kargs.scale,
1133  rp_undrop,
1134  scale_rp_undrop,
1135  dropout);
1136 
1137  KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile, nullptr);
1138  VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile, nullptr);
1139  }
1140  else
1141  {
1142  FmhaPipeline{}(smem_ptr,
1143  q_dram_window,
1144  k_dram_window,
1145  v_dram_window,
1146  bias_dram_window,
1147  randval_dram_window,
1148  do_dram_window,
1149  lse_dram_window,
1150  d_dram_window,
1151  dq_dram_window,
1152  dk_dram_window,
1153  dv_dram_window,
1154  dbias_dram_window,
1158  mask,
1159  position_encoding,
1160  kargs.raw_scale,
1161  kargs.scale,
1162  rp_undrop,
1163  scale_rp_undrop,
1164  dropout);
1165  }
1166  }
1167 };
1168 
1169 template <typename FmhaBwdOGradDotO_>
1170 struct FmhaBwdOGradDotOKernel
1171 {
1172  using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>;
1173  static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize;
1174  static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
1175  static constexpr ck_tile::index_t kM0 = kBlockSize;
1176  static constexpr ck_tile::index_t kVHeaddim = FmhaBwdOGradDotO::kVHeaddim;
1177 
1178  using ODataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::ODataType>;
1179  using OGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::OGradDataType>;
1180  using DDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::DDataType>;
1181 
1182  static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode;
1183  static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ;
1184  static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV;
1185 
1186  // clang-format off
1187  template <typename T> struct t2s;
1188  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
1189  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
1190  // clang-format on
1191 
1192  CK_TILE_HOST static std::string GetName()
1193  {
1194  // sync with generate.py
1195  // clang-format off
1196 
1197  #define _SS_ std::string
1198  #define _TS_ std::to_string
1199  auto pn = [&] () {
1200  std::string n;
1201  if (kPadSeqLenQ) n += "s";
1202  if (kPadHeadDimV) n += "dv";
1203  return n.empty() ? n : std::string("p") + n; }();
1204  return
1205  _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s<ODataType>::name) +
1206  "_" + (kIsGroupMode ? "group" : "batch") + "_" +
1207  ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn);
1208  #undef _SS_
1209  #undef _TS_
1210  // clang-format on
1211  }
1212 
1213  // kargs use aggregate initialization, so no constructor is provided
1214  // inheritance is used to minimize the karg size
1215  // users need to call the MakeKargs() function to create kargs
1216  struct FmhaBwdOGradDotOCommonKargs
1217  {
1218  const void* o_ptr;
1219  const void* do_ptr;
1220  void* d_ptr;
1221 
1222  float p_undrop;
1223 
1224  ck_tile::index_t seqlen_q;
1225  ck_tile::index_t hdim_v;
1226 
1227  ck_tile::index_t stride_do;
1228  ck_tile::index_t stride_o;
1229 
1230  ck_tile::index_t nhead_stride_do;
1231  ck_tile::index_t nhead_stride_o;
1232  ck_tile::index_t nhead_stride_d;
1233  };
1234 
1235  struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
1236  {
1237  ck_tile::index_t batch_stride_do;
1238  ck_tile::index_t batch_stride_o;
1239  ck_tile::index_t batch_stride_d;
1240  };
1241 
1242  struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
1243  {
1244  const int32_t* seqstart_q_ptr;
1245  };
1246 
1247  using Kargs = std::
1248  conditional_t<kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs>;
1249 
1250  template <bool Cond = !kIsGroupMode>
1251  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1252  MakeKargs(const void* o_ptr,
1253  const void* do_ptr,
1254  void* d_ptr,
1255  float p_undrop,
1256  ck_tile::index_t seqlen_q,
1257  ck_tile::index_t hdim_v,
1258  ck_tile::index_t stride_do,
1259  ck_tile::index_t stride_o,
1260  ck_tile::index_t nhead_stride_do,
1261  ck_tile::index_t nhead_stride_o,
1262  ck_tile::index_t nhead_stride_d,
1263  ck_tile::index_t batch_stride_do,
1264  ck_tile::index_t batch_stride_o,
1265  ck_tile::index_t batch_stride_d)
1266  {
1267  Kargs kargs{{o_ptr,
1268  do_ptr,
1269  d_ptr,
1270  p_undrop,
1271  seqlen_q,
1272  hdim_v,
1273  stride_do,
1274  stride_o,
1275  nhead_stride_do,
1276  nhead_stride_o,
1277  nhead_stride_d},
1278  batch_stride_do,
1279  batch_stride_o,
1280  batch_stride_d};
1281 
1282  return kargs;
1283  }
1284 
1285  template <bool Cond = kIsGroupMode>
1286  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1287  MakeKargs(const void* o_ptr,
1288  const void* do_ptr,
1289  void* d_ptr,
1290  float p_undrop,
1291  const void* seqstart_q_ptr,
1292  ck_tile::index_t hdim_v,
1293  ck_tile::index_t stride_do,
1294  ck_tile::index_t stride_o,
1295  ck_tile::index_t nhead_stride_do,
1296  ck_tile::index_t nhead_stride_o,
1297  ck_tile::index_t nhead_stride_d)
1298  {
1299  Kargs kargs{{o_ptr,
1300  do_ptr,
1301  d_ptr,
1302  p_undrop,
1303  -1, // seqlen will be updated by another pointer
1304  hdim_v,
1305  stride_do,
1306  stride_o,
1307  nhead_stride_do,
1308  nhead_stride_o,
1309  nhead_stride_d},
1310  reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
1311 
1312  return kargs;
1313  }
1314 
1315  CK_TILE_HOST static constexpr auto
1316  GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
1317  {
1318  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
1319  }
1320 
1321  CK_TILE_DEVICE static constexpr auto GetTileIndex()
1322  {
1323  const index_t i_block = blockIdx.x;
1324  const index_t i_nhead = blockIdx.y;
1325  const index_t i_batch = blockIdx.z;
1326 
1327  return ck_tile::make_tuple(i_block, i_nhead, i_batch);
1328  }
1329 
1330  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
1331 
1332  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
1333 
1334  CK_TILE_DEVICE void operator()(Kargs kargs) const
1335  {
1336  // divide problem
1337  const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
1338 
1339  const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
1340 
1341  long_index_t batch_offset_o = 0;
1342  long_index_t batch_offset_do = 0;
1343  long_index_t batch_offset_d = 0;
1344 
1345  if constexpr(kIsGroupMode)
1346  {
1347  // get starting offset for each batch
1348  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1349 
1350  batch_offset_o = query_start * kargs.stride_o;
1351  batch_offset_do = query_start * kargs.stride_do;
1352  batch_offset_d = query_start;
1353 
1354  // get real # queries & # keys under group mode
1355  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1356  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1357  // the number of required blocks differs between groups, so terminate unnecessary
1358  // blocks early
1359  if(kargs.seqlen_q <= i_m0)
1360  {
1361  return;
1362  }
1363  }
1364  else
1365  {
1366  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1367  batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
1368  batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
1369  }
1370 
1371  // for simplicity, we apply the batch stride by adjusting the pointer
1372  const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr) +
1373  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1374  batch_offset_o;
1375  const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
1376  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
1377  batch_offset_do;
1378  DDataType* d_ptr = reinterpret_cast<DDataType*>(kargs.d_ptr) +
1379  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_d +
1380  batch_offset_d;
1381 
1382  // O/dO/D DRAM and DRAM window
1383  const auto o_dram = [&]() {
1384  auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1385  o_ptr,
1386  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1387  make_tuple(kargs.stride_o, 1),
1389  number<1>{});
1390  return pad_tensor_view(o_dram_naive,
1393  }();
1394  const auto do_dram = [&]() {
1395  auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1396  do_ptr,
1397  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1398  make_tuple(kargs.stride_do, 1),
1400  number<1>{});
1401  return pad_tensor_view(do_dram_naive,
1404  }();
1405  auto d_dram = [&]() {
1406  const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
1407  d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
1408  return pad_tensor_view(
1409  d_dram_naive, make_tuple(number<kM0>{}), sequence<kPadSeqLenQ>{});
1410  }();
1411 
1412  auto o_dram_window =
1413  make_tile_window(o_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
1414 
1415  auto do_dram_window =
1416  make_tile_window(do_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
1417 
1418  auto d_dram_window = make_tile_window(d_dram, make_tuple(number<kM0>{}), {i_m0});
1419 
1420  FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop);
1421  }
1422 };
1423 
1424 template <typename FmhaBwdConvertQGrad_>
1425 struct FmhaBwdConvertQGradKernel
1426 {
1427  using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>;
1428  static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize;
1429  static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
1430  static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0;
1431  static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0;
1432  static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim;
1433 
1434  using AccDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::AccDataType>;
1435  using QGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradDataType>;
1436 
1437  static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode;
1438  static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
1439  static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
1440  static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
1441 
1442  // clang-format off
1443  template <typename T> struct t2s;
1444  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
1445  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
1446  // clang-format on
1447 
1448  CK_TILE_HOST static std::string GetName()
1449  {
1450  // sync with generate.py
1451  // clang-format off
1452 
1453  #define _SS_ std::string
1454  #define _TS_ std::to_string
1455  auto pn = [&] () {
1456  std::string n;
1457  if (kPadSeqLenQ) n += "s";
1458  if (kPadHeadDimQ) n += "d";
1459  return n.empty() ? n : std::string("p") + n; }();
1460  return
1461  _SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_"
1463  + "b" + _TS_(kM0) + "x" + _TS_(kN0) + "_"
1464  + (kIsGroupMode ? "group" : "batch") + "_"
1465  + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn)
1466  + (kIsDeterministic ? "_deterministic" : "_ndeterministic") ;
1467  #undef _SS_
1468  #undef _TS_
1469  // clang-format on
1470  }
1471 
1472  // to avoid the duplicated base class problem, introduce a template arg
1473  template <ck_tile::index_t I>
1474  struct FmhaBwdConvertQGradEmptyKargs
1475  {
1476  };
1477 
1478  // kargs use aggregate initialization, so no constructor is provided
1479  // inheritance is used to minimize the karg size
1480  // users need to call the MakeKargs() function to create kargs
1481  struct FmhaBwdConvertQGradCommonKargs
1482  {
1483  const void* dq_acc_ptr;
1484  void* dq_ptr;
1485 
1486  ck_tile::index_t seqlen_q;
1487  ck_tile::index_t seqlen_k;
1488  ck_tile::index_t hdim_q;
1489 
1490  ck_tile::index_t stride_dq;
1491  ck_tile::index_t stride_dq_acc;
1492  ck_tile::index_t nhead_stride_dq;
1493  ck_tile::index_t nhead_stride_dq_acc;
1494  };
1495 
1496  struct FmhaBwdConvertQGradDeterministicKargs
1497  {
1498  ck_tile::index_t split_stride_dq_acc = 0;
1499  };
1500 
1501  struct FmhaBwdConvertQGradBatchModeKargs
1502  : FmhaBwdConvertQGradCommonKargs,
1503  std::conditional_t<kIsDeterministic,
1504  FmhaBwdConvertQGradDeterministicKargs,
1505  FmhaBwdConvertQGradEmptyKargs<0>>
1506  {
1507  ck_tile::index_t batch_stride_dq;
1508  ck_tile::index_t batch_stride_dq_acc;
1509  };
1510 
1511  struct FmhaBwdConvertQGradGroupModeKargs
1512  : FmhaBwdConvertQGradCommonKargs,
1513  std::conditional_t<kIsDeterministic,
1514  FmhaBwdConvertQGradDeterministicKargs,
1515  FmhaBwdConvertQGradEmptyKargs<0>>
1516  {
1517  const int32_t* seqstart_q_ptr;
1518  const int32_t* seqstart_k_ptr;
1519  };
1520 
1521  using Kargs = std::conditional_t<kIsGroupMode,
1522  FmhaBwdConvertQGradGroupModeKargs,
1523  FmhaBwdConvertQGradBatchModeKargs>;
1524 
1525  template <bool Cond = !kIsGroupMode>
1526  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1527  MakeKargs(const void* dq_acc_ptr,
1528  void* dq_ptr,
1529  ck_tile::index_t seqlen_q,
1530  ck_tile::index_t seqlen_k,
1531  ck_tile::index_t hdim_q,
1532  ck_tile::index_t stride_dq,
1533  ck_tile::index_t stride_dq_acc,
1534  ck_tile::index_t nhead_stride_dq,
1535  ck_tile::index_t nhead_stride_dq_acc,
1536  ck_tile::index_t batch_stride_dq,
1537  ck_tile::index_t batch_stride_dq_acc,
1538  ck_tile::index_t split_stride_dq_acc)
1539  {
1540  Kargs kargs{{dq_acc_ptr,
1541  dq_ptr,
1542  seqlen_q,
1543  seqlen_k,
1544  hdim_q,
1545  stride_dq,
1546  stride_dq_acc,
1547  nhead_stride_dq,
1548  nhead_stride_dq_acc},
1549  {},
1550  batch_stride_dq,
1551  batch_stride_dq_acc};
1552 
1553  if constexpr(kIsDeterministic)
1554  {
1555  kargs.split_stride_dq_acc = split_stride_dq_acc;
1556  }
1557 
1558  return kargs;
1559  }
1560 
1561  template <bool Cond = kIsGroupMode>
1562  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1563  MakeKargs(const void* dq_acc_ptr,
1564  void* dq_ptr,
1565  const void* seqstart_q_ptr,
1566  const void* seqstart_k_ptr,
1567  ck_tile::index_t hdim_q,
1568  ck_tile::index_t stride_dq,
1569  ck_tile::index_t stride_dq_acc,
1570  ck_tile::index_t nhead_stride_dq,
1571  ck_tile::index_t nhead_stride_dq_acc,
1572  ck_tile::index_t split_stride_dq_acc)
1573  {
1574  Kargs kargs{{dq_acc_ptr,
1575  dq_ptr,
1576  -1, // seqlen will be updated by another pointer
1577  -1, //
1578  hdim_q,
1579  stride_dq,
1580  stride_dq_acc,
1581  nhead_stride_dq,
1582  nhead_stride_dq_acc},
1583  {},
1584  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
1585  reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
1586 
1587  if constexpr(kIsDeterministic)
1588  {
1589  kargs.split_stride_dq_acc = split_stride_dq_acc;
1590  }
1591 
1592  return kargs;
1593  }
1594 
1595  CK_TILE_HOST static constexpr auto
1596  GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
1597  {
1598  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
1599  }
1600 
1601  CK_TILE_DEVICE static constexpr auto GetTileIndex()
1602  {
1603  const index_t i_block = blockIdx.x;
1604  const index_t i_nhead = blockIdx.y;
1605  const index_t i_batch = blockIdx.z;
1606 
1607  return ck_tile::make_tuple(i_block, i_nhead, i_batch);
1608  }
1609 
1610  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
1611 
1612  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
1613 
1614  CK_TILE_DEVICE void operator()(Kargs kargs) const
1615  {
1616  // divide problem
1617  const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
1618 
1619  const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
1620 
1621  long_index_t batch_offset_dq = 0;
1622  long_index_t batch_offset_dq_acc = 0;
1623  if constexpr(kIsGroupMode)
1624  {
1625  // get starting offset for each batch
1626  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1627  batch_offset_dq = query_start * kargs.stride_dq;
1628  batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
1629 
1630  // get real # queries & # keys under group mode
1631  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1632  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1633  if constexpr(kIsDeterministic)
1634  {
1635  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1636  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1637  }
1638  // the number of required blocks differs between groups, so terminate unnecessary
1639  // blocks early
1640  if(kargs.seqlen_q <= i_m0)
1641  {
1642  return;
1643  }
1644  }
1645  else
1646  {
1647  batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
1648  batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
1649  }
1650 
1651  // for simplicity, we apply the batch stride by adjusting the pointer
1652  QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
1653  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
1654  batch_offset_dq;
1655 
1656  // dQAcc/dQ DRAM and DRAM window
1657  const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
1658  if constexpr(kIsDeterministic)
1659  {
1660  const AccDataType* dq_acc_ptr =
1661  reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
1662  static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
1663  batch_offset_dq_acc;
1664 
1665  const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
1666 
1667  auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1668  dq_acc_ptr,
1669  make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
1670  make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
1672  number<1>{});
1673  return pad_tensor_view(dq_acc_dram_naive,
1676  }
1677  else
1678  {
1679  const AccDataType* dq_acc_ptr =
1680  reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
1681  static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
1682  batch_offset_dq_acc;
1683 
1684  auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1685  dq_acc_ptr,
1686  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1687  make_tuple(kargs.stride_dq_acc, 1),
1689  number<1>{});
1690  return pad_tensor_view(dq_acc_dram_naive,
1693  }
1694  }();
1695 
1696  auto dq_dram = [&]() {
1697  auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1698  dq_ptr,
1699  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1700  make_tuple(kargs.stride_dq, 1),
1702  number<1>{});
1703  return pad_tensor_view(dq_dram_naive,
1706  }();
1707 
1708  auto dq_acc_dram_window = [&]() {
1709  if constexpr(kIsDeterministic)
1710  {
1711  return make_tile_window(
1712  dq_acc_dram,
1714  {0, i_m0, 0});
1715  }
1716  else
1717  {
1718  return make_tile_window(
1719  dq_acc_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
1720  }
1721  }();
1722 
1723  auto dq_dram_window =
1724  make_tile_window(dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
1725 
1726  if constexpr(kIsDeterministic)
1727  {
1728  const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
1729  FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
1730  }
1731  else
1732  {
1733  FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
1734  }
1735  }
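// Hedged sketch of the conversion performed above: deterministic mode reduces
// the nsplits per-k-tile accumulators before the final type cast, otherwise it
// is a straight cast/copy of the atomically accumulated buffer:
//
//   dq[m][d] = QGradDataType(sum over s < nsplits of dq_acc[s][m][d]) // deterministic
//   dq[m][d] = QGradDataType(dq_acc[m][d])                            // otherwise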
1736 };
1737 
1738 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
#define _TS_
#define _SS_
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__device__ X atomic_add(X *p_dst, const X &x)
unsigned char uint8_t
Definition: stdint.h:124
unsigned __int64 uint64_t
Definition: stdint.h:136
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_position_encoding.hpp:137
ck_tile::index_t batch_stride_dq
Definition: fmha_bwd_kernel.hpp:1507
ck_tile::index_t batch_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1508
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:1486
ck_tile::index_t stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1491
ck_tile::index_t nhead_stride_dq
Definition: fmha_bwd_kernel.hpp:1492
ck_tile::index_t hdim_q
Definition: fmha_bwd_kernel.hpp:1488
ck_tile::index_t stride_dq
Definition: fmha_bwd_kernel.hpp:1490
const void * dq_acc_ptr
Definition: fmha_bwd_kernel.hpp:1483
ck_tile::index_t seqlen_k
Definition: fmha_bwd_kernel.hpp:1487
ck_tile::index_t nhead_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1493
ck_tile::index_t split_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1498
const int32_t * seqstart_k_ptr
Definition: fmha_bwd_kernel.hpp:1518
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:1517
Definition: fmha_bwd_kernel.hpp:1443
Definition: fmha_bwd_kernel.hpp:1426
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:1437
static constexpr bool kPadHeadDimQ
Definition: fmha_bwd_kernel.hpp:1439
static constexpr bool kIsDeterministic
Definition: fmha_bwd_kernel.hpp:1440
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:1428
static constexpr bool kPadSeqLenQ
Definition: fmha_bwd_kernel.hpp:1438
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition: fmha_bwd_kernel.hpp:1563
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:1429
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:1601
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t batch_stride_dq, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition: fmha_bwd_kernel.hpp:1527
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:1612
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:1614
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::AccDataType > AccDataType
Definition: fmha_bwd_kernel.hpp:1434
static constexpr ck_tile::index_t kM0
Definition: fmha_bwd_kernel.hpp:1430
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::QGradDataType > QGradDataType
Definition: fmha_bwd_kernel.hpp:1435
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition: fmha_bwd_kernel.hpp:1596
static constexpr ck_tile::index_t kN0
Definition: fmha_bwd_kernel.hpp:1431
ck_tile::remove_cvref_t< FmhaBwdConvertQGrad_ > FmhaBwdConvertQGrad
Definition: fmha_bwd_kernel.hpp:1427
std::conditional_t< kIsGroupMode, FmhaBwdConvertQGradGroupModeKargs, FmhaBwdConvertQGradBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:1523
static constexpr ck_tile::index_t kQKHeaddim
Definition: fmha_bwd_kernel.hpp:1432
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:1610
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:1448
(struct) Definition: fmha_bwd_kernel.hpp:189
const void * alibi_slope_ptr (Definition: fmha_bwd_kernel.hpp:191)
ck_tile::index_t alibi_slope_stride (Definition: fmha_bwd_kernel.hpp:192)
ck_tile::index_t batch_stride_dbias (Definition: fmha_bwd_kernel.hpp:204)
ck_tile::index_t batch_stride_bias (Definition: fmha_bwd_kernel.hpp:185)
ck_tile::index_t batch_stride_randval (Definition: fmha_bwd_kernel.hpp:269)
(struct) Definition: fmha_bwd_kernel.hpp:288
ck_tile::index_t batch_stride_v (Definition: fmha_bwd_kernel.hpp:291)
ck_tile::index_t batch_stride_k (Definition: fmha_bwd_kernel.hpp:290)
ck_tile::index_t batch_stride_q (Definition: fmha_bwd_kernel.hpp:289)
ck_tile::index_t batch_stride_do (Definition: fmha_bwd_kernel.hpp:292)
ck_tile::index_t batch_stride_dq_acc (Definition: fmha_bwd_kernel.hpp:294)
ck_tile::index_t batch_stride_dk (Definition: fmha_bwd_kernel.hpp:295)
ck_tile::index_t batch_stride_dv (Definition: fmha_bwd_kernel.hpp:296)
ck_tile::index_t batch_stride_lsed (Definition: fmha_bwd_kernel.hpp:293)
ck_tile::index_t nhead_stride_dbias (Definition: fmha_bwd_kernel.hpp:199)
void * dbias_ptr (Definition: fmha_bwd_kernel.hpp:197)
ck_tile::index_t stride_dbias (Definition: fmha_bwd_kernel.hpp:198)
(struct) Definition: fmha_bwd_kernel.hpp:177
ck_tile::index_t stride_bias (Definition: fmha_bwd_kernel.hpp:179)
ck_tile::index_t nhead_stride_bias (Definition: fmha_bwd_kernel.hpp:180)
const void * bias_ptr (Definition: fmha_bwd_kernel.hpp:178)
uint8_t p_undrop_in_uint8_t (Definition: fmha_bwd_kernel.hpp:260)
float rp_undrop (Definition: fmha_bwd_kernel.hpp:258)
ck_tile::index_t nhead_stride_randval (Definition: fmha_bwd_kernel.hpp:264)
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr, float raw_scale) (Definition: fmha_bwd_kernel.hpp:242)
float scale_rp_undrop (Definition: fmha_bwd_kernel.hpp:259)
void * rand_val_ptr (Definition: fmha_bwd_kernel.hpp:261)
ck_tile::index_t stride_randval (Definition: fmha_bwd_kernel.hpp:263)
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale) (Definition: fmha_bwd_kernel.hpp:229)
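The two init_dropout overloads accept the philox seed/offset either as host values or as device pointers, and populate the derived fields named above. One plausible derivation of those fields from p_drop and raw_scale, shown only to make the naming concrete; the library's exact arithmetic and rounding may differ:

#include <cmath>
#include <cstdint>

// Hypothetical derivation of the dropout-related kargs fields from p_drop;
// consult init_dropout at fmha_bwd_kernel.hpp:229 for the authoritative version.
struct DropoutParams
{
    float   rp_undrop;            // 1 / (1 - p_drop), rescales kept values
    float   scale_rp_undrop;      // softmax scale folded with rp_undrop
    uint8_t p_undrop_in_uint8_t;  // keep-probability quantized to 8 bits
};

inline DropoutParams make_dropout_params(float p_drop, float raw_scale)
{
    const float p_undrop = 1.0f - p_drop;
    DropoutParams p{};
    p.rp_undrop           = 1.0f / p_undrop;
    p.scale_rp_undrop     = raw_scale * p.rp_undrop;
    p.p_undrop_in_uint8_t = static_cast<uint8_t>(std::floor(p_undrop * 255.0f));
    return p;
}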
(struct) Definition: fmha_bwd_kernel.hpp:135
ck_tile::index_t nhead_stride_dk (Definition: fmha_bwd_kernel.hpp:172)
ck_tile::index_t stride_do (Definition: fmha_bwd_kernel.hpp:161)
ck_tile::index_t seqlen_k (Definition: fmha_bwd_kernel.hpp:147)
const void * q_ptr (Definition: fmha_bwd_kernel.hpp:136)
ck_tile::index_t hdim_q (Definition: fmha_bwd_kernel.hpp:148)
ck_tile::index_t nhead_stride_do (Definition: fmha_bwd_kernel.hpp:169)
ck_tile::index_t num_head_q (Definition: fmha_bwd_kernel.hpp:153)
const void * lse_ptr (Definition: fmha_bwd_kernel.hpp:139)
float raw_scale (Definition: fmha_bwd_kernel.hpp:155)
ck_tile::index_t nhead_stride_k (Definition: fmha_bwd_kernel.hpp:167)
ck_tile::index_t nhead_stride_q (Definition: fmha_bwd_kernel.hpp:166)
ck_tile::index_t stride_dv (Definition: fmha_bwd_kernel.hpp:164)
ck_tile::index_t nhead_stride_lsed (Definition: fmha_bwd_kernel.hpp:170)
void * dq_acc_ptr (Definition: fmha_bwd_kernel.hpp:142)
ck_tile::index_t stride_q (Definition: fmha_bwd_kernel.hpp:158)
ck_tile::index_t seqlen_q (Definition: fmha_bwd_kernel.hpp:146)
ck_tile::index_t stride_dk (Definition: fmha_bwd_kernel.hpp:163)
const void * do_ptr (Definition: fmha_bwd_kernel.hpp:140)
float scale (Definition: fmha_bwd_kernel.hpp:156)
void * dk_ptr (Definition: fmha_bwd_kernel.hpp:143)
ck_tile::index_t nhead_stride_v (Definition: fmha_bwd_kernel.hpp:168)
ck_tile::index_t stride_v (Definition: fmha_bwd_kernel.hpp:160)
const void * d_ptr (Definition: fmha_bwd_kernel.hpp:141)
const void * k_ptr (Definition: fmha_bwd_kernel.hpp:137)
ck_tile::index_t nhead_stride_dq_acc (Definition: fmha_bwd_kernel.hpp:171)
ck_tile::index_t nhead_ratio_qk (Definition: fmha_bwd_kernel.hpp:154)
void * dv_ptr (Definition: fmha_bwd_kernel.hpp:144)
const void * v_ptr (Definition: fmha_bwd_kernel.hpp:138)
ck_tile::index_t hdim_v (Definition: fmha_bwd_kernel.hpp:149)
ck_tile::index_t stride_k (Definition: fmha_bwd_kernel.hpp:159)
ck_tile::index_t nhead_stride_dv (Definition: fmha_bwd_kernel.hpp:173)
ck_tile::index_t stride_dq_acc (Definition: fmha_bwd_kernel.hpp:162)
ck_tile::index_t split_stride_dq_acc (Definition: fmha_bwd_kernel.hpp:274)
bool is_drop_seed_offset_from_host (Definition: fmha_bwd_kernel.hpp:224)
ValueOrPointer< uint64_t > drop_seed (Definition: fmha_bwd_kernel.hpp:222)
ValueOrPointer< uint64_t > drop_offset (Definition: fmha_bwd_kernel.hpp:223)
(struct) Definition: fmha_bwd_kernel.hpp:128
(struct) Definition: fmha_bwd_kernel.hpp:310
const int32_t * seqstart_k_ptr (Definition: fmha_bwd_kernel.hpp:312)
const int32_t * seqstart_q_ptr (Definition: fmha_bwd_kernel.hpp:311)
const int32_t * seqlen_k_ptr (Definition: fmha_bwd_kernel.hpp:313)
(struct) Definition: fmha_bwd_kernel.hpp:208
ck_tile::GenericAttentionMaskEnum mask_type (Definition: fmha_bwd_kernel.hpp:210)
ck_tile::index_t window_size_right (Definition: fmha_bwd_kernel.hpp:209)
ck_tile::index_t window_size_left (Definition: fmha_bwd_kernel.hpp:209)
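window_size_left/window_size_right describe a sliding-window mask around the attention diagonal, with mask_type selecting the variant via GenericAttentionMaskEnum. A generic predicate for the common convention (negative window size meaning unbounded), offered as an illustration rather than the library's exact rule:

#include <cstdint>

// Illustrative sliding-window mask test: query row i may attend to key column j
// iff j lies within [i - window_left, i + window_right]; a negative window size
// is treated as unbounded. The library's exact convention, including how the
// seqlen_q/seqlen_k difference shifts the diagonal, may differ.
inline bool is_unmasked(int64_t i, int64_t j, int64_t window_left, int64_t window_right)
{
    const bool left_ok  = (window_left < 0)  || (j >= i - window_left);
    const bool right_ok = (window_right < 0) || (j <= i + window_right);
    return left_ok && right_ok;
}
// Causal masking is the special case window_left < 0, window_right == 0.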
(struct) Definition: fmha_bwd_kernel.hpp:84
(struct) Definition: fmha_bwd_kernel.hpp:35
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset) (Definition: fmha_bwd_kernel.hpp:506)
static constexpr CK_TILE_HOST auto BlockSize() (Definition: fmha_bwd_kernel.hpp:668)
ck_tile::remove_cvref_t< KGradEpiloguePipeline_ > KGradEpiloguePipeline (Definition: fmha_bwd_kernel.hpp:37)
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset) (Definition: fmha_bwd_kernel.hpp:338)
static constexpr auto BiasEnum (Definition: fmha_bwd_kernel.hpp:66)
static CK_TILE_HOST std::string GetName() (Definition: fmha_bwd_kernel.hpp:89)
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaDropout > FmhaDropout (Definition: fmha_bwd_kernel.hpp:69)
ck_tile::remove_cvref_t< typename FmhaPipeline::OGradDataType > OGradDataType (Definition: fmha_bwd_kernel.hpp:57)
static constexpr bool kUseQrQtrDorPipeline (Definition: fmha_bwd_kernel.hpp:42)
ck_tile::remove_cvref_t< VGradEpiloguePipeline_ > VGradEpiloguePipeline (Definition: fmha_bwd_kernel.hpp:38)
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasGradDataType > BiasGradDataType (Definition: fmha_bwd_kernel.hpp:61)
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize() (Definition: fmha_bwd_kernel.hpp:670)
static constexpr ck_tile::index_t kBlockSize (Definition: fmha_bwd_kernel.hpp:40)
static constexpr bool kHasMask (Definition: fmha_bwd_kernel.hpp:70)
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType (Definition: fmha_bwd_kernel.hpp:48)
static constexpr CK_TILE_HOST Kargs MakeKargs(Ts... args, const std::tuple< uint64_t, uint64_t > &drop_seed_offset) (Definition: fmha_bwd_kernel.hpp:321)
static constexpr bool kIsGroupMode (Definition: fmha_bwd_kernel.hpp:63)
ck_tile::remove_cvref_t< typename FmhaPipeline::QGradDataType > QGradDataType (Definition: fmha_bwd_kernel.hpp:58)
static constexpr CK_TILE_DEVICE auto GetTileIndex() (Definition: fmha_bwd_kernel.hpp:659)
static constexpr bool kPadHeadDimV (Definition: fmha_bwd_kernel.hpp:65)
ck_tile::remove_cvref_t< typename FmhaPipeline::DDataType > DDataType (Definition: fmha_bwd_kernel.hpp:54)
static constexpr bool kIsAvialable (Definition: fmha_bwd_kernel.hpp:80)
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType (Definition: fmha_bwd_kernel.hpp:47)
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType (Definition: fmha_bwd_kernel.hpp:52)
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) (Definition: fmha_bwd_kernel.hpp:651)
ck_tile::remove_cvref_t< typename FmhaPipeline::VGradDataType > VGradDataType (Definition: fmha_bwd_kernel.hpp:60)
ck_tile::remove_cvref_t< typename FmhaPipeline::GemmDataType > GemmDataType (Definition: fmha_bwd_kernel.hpp:51)
CK_TILE_DEVICE void operator()(Kargs kargs) const (Definition: fmha_bwd_kernel.hpp:677)
ck_tile::remove_cvref_t< QGradEpiloguePipeline_ > QGradEpiloguePipeline (Definition: fmha_bwd_kernel.hpp:39)
static constexpr bool kPadHeadDimQ (Definition: fmha_bwd_kernel.hpp:64)
static constexpr bool kHasDropout (Definition: fmha_bwd_kernel.hpp:71)
ck_tile::remove_cvref_t< typename FmhaPipeline::AccDataType > AccDataType (Definition: fmha_bwd_kernel.hpp:53)
CK_TILE_DEVICE void run_(Kargs kargs) const (Definition: fmha_bwd_kernel.hpp:683)
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType (Definition: fmha_bwd_kernel.hpp:49)
static constexpr bool kHasBiasGrad (Definition: fmha_bwd_kernel.hpp:67)
static constexpr bool kIsDeterministic (Definition: fmha_bwd_kernel.hpp:73)
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType (Definition: fmha_bwd_kernel.hpp:50)
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType (Definition: fmha_bwd_kernel.hpp:56)
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline (Definition: fmha_bwd_kernel.hpp:36)
static constexpr CK_TILE_HOST Kargs MakeKargs(Ts... args, const std::tuple< const void *, const void * > &drop_seed_offset) (Definition: fmha_bwd_kernel.hpp:330)
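The two MakeKargs overloads above take the philox seed/offset either by value (std::tuple<uint64_t, uint64_t>) or as device pointers (std::tuple<const void*, const void*>), and the MakeKargsImpl signatures carry it onward as a std::variant of pairs. A small self-contained sketch of that dispatch shape; consume_seed_offset is an illustrative helper, not library code:

#include <cstdint>
#include <utility>
#include <variant>

using SeedOffset =
    std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>;

// Hypothetical consumer mirroring how a kargs-builder might branch on whether
// the seed/offset came from the host (values) or live in device memory (pointers).
inline bool consume_seed_offset(const SeedOffset& so, uint64_t& seed, uint64_t& offset,
                                const void*& seed_ptr, const void*& offset_ptr)
{
    if(const auto* host = std::get_if<std::pair<uint64_t, uint64_t>>(&so))
    {
        seed   = host->first;
        offset = host->second;
        return true; // analogous to is_drop_seed_offset_from_host == true
    }
    const auto& dev = std::get<std::pair<const void*, const void*>>(so);
    seed_ptr   = dev.first;
    offset_ptr = dev.second;
    return false;
}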
ck_tile::remove_cvref_t< typename FmhaPipeline::KGradDataType > KGradDataType (Definition: fmha_bwd_kernel.hpp:59)
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask (Definition: fmha_bwd_kernel.hpp:68)
static constexpr bool kUseTrLoad (Definition: fmha_bwd_kernel.hpp:74)
static constexpr ck_tile::index_t kBlockPerCu (Definition: fmha_bwd_kernel.hpp:41)
std::conditional_t< kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs > Kargs (Definition: fmha_bwd_kernel.hpp:316)
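The Kargs alias selects the batch-mode or group-mode argument struct at compile time via std::conditional_t on kIsGroupMode, so the rest of the kernel is written once against a single Kargs name. The mechanism in isolation, with toy struct bodies standing in for the real kargs:

#include <type_traits>

struct BatchModeKargs { int seqlen_q; int seqlen_k; };
struct GroupModeKargs { const int* seqstart_q_ptr; const int* seqstart_k_ptr; };

template <bool kIsGroupMode>
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;

static_assert(std::is_same_v<Kargs<false>, BatchModeKargs>);
static_assert(std::is_same_v<Kargs<true>, GroupModeKargs>);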
static constexpr bool kIsStoreRandval (Definition: fmha_bwd_kernel.hpp:72)
static constexpr index_t kMaxSeqLenQ (Definition: fmha_bwd_kernel.hpp:75)
ck_tile::index_t batch_stride_o (Definition: fmha_bwd_kernel.hpp:1238)
ck_tile::index_t batch_stride_do (Definition: fmha_bwd_kernel.hpp:1237)
ck_tile::index_t batch_stride_d (Definition: fmha_bwd_kernel.hpp:1239)
void * d_ptr (Definition: fmha_bwd_kernel.hpp:1220)
const void * o_ptr (Definition: fmha_bwd_kernel.hpp:1218)
ck_tile::index_t hdim_v (Definition: fmha_bwd_kernel.hpp:1225)
ck_tile::index_t nhead_stride_do (Definition: fmha_bwd_kernel.hpp:1230)
ck_tile::index_t stride_o (Definition: fmha_bwd_kernel.hpp:1228)
ck_tile::index_t nhead_stride_o (Definition: fmha_bwd_kernel.hpp:1231)
const void * do_ptr (Definition: fmha_bwd_kernel.hpp:1219)
ck_tile::index_t stride_do (Definition: fmha_bwd_kernel.hpp:1227)
ck_tile::index_t seqlen_q (Definition: fmha_bwd_kernel.hpp:1224)
float p_undrop (Definition: fmha_bwd_kernel.hpp:1222)
ck_tile::index_t nhead_stride_d (Definition: fmha_bwd_kernel.hpp:1232)
const int32_t * seqstart_q_ptr (Definition: fmha_bwd_kernel.hpp:1244)
(struct) Definition: fmha_bwd_kernel.hpp:1187
(struct) Definition: fmha_bwd_kernel.hpp:1171
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::ODataType > ODataType (Definition: fmha_bwd_kernel.hpp:1179)
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize() (Definition: fmha_bwd_kernel.hpp:1332)
ck_tile::remove_cvref_t< FmhaBwdOGradDotO_ > FmhaBwdOGradDotO (Definition: fmha_bwd_kernel.hpp:1172)
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_d) (Definition: fmha_bwd_kernel.hpp:1252)
static constexpr ck_tile::index_t kVHeaddim (Definition: fmha_bwd_kernel.hpp:1176)
CK_TILE_DEVICE void operator()(Kargs kargs) const (Definition: fmha_bwd_kernel.hpp:1334)
static constexpr bool kIsGroupMode (Definition: fmha_bwd_kernel.hpp:1182)
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::OGradDataType > OGradDataType (Definition: fmha_bwd_kernel.hpp:1180)
static constexpr ck_tile::index_t kM0 (Definition: fmha_bwd_kernel.hpp:1175)
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, const void *seqstart_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d) (Definition: fmha_bwd_kernel.hpp:1287)
static constexpr ck_tile::index_t kBlockSize (Definition: fmha_bwd_kernel.hpp:1173)
static constexpr ck_tile::index_t kBlockPerCu (Definition: fmha_bwd_kernel.hpp:1174)
static constexpr CK_TILE_DEVICE auto GetTileIndex() (Definition: fmha_bwd_kernel.hpp:1321)
static constexpr bool kPadSeqLenQ (Definition: fmha_bwd_kernel.hpp:1183)
std::conditional_t< kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs > Kargs (Definition: fmha_bwd_kernel.hpp:1248)
static CK_TILE_HOST std::string GetName() (Definition: fmha_bwd_kernel.hpp:1192)
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::DDataType > DDataType (Definition: fmha_bwd_kernel.hpp:1178)
static constexpr bool kPadHeadDimV (Definition: fmha_bwd_kernel.hpp:1184)
static constexpr CK_TILE_HOST auto BlockSize() (Definition: fmha_bwd_kernel.hpp:1330)
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) (Definition: fmha_bwd_kernel.hpp:1316)
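This kernel's inputs (o_ptr, do_ptr) and output (d_ptr) implement the standard FlashAttention-backward row reduction D[i] = rowsum(dO[i, :] * O[i, :]), with p_undrop available for rescaling. A scalar reference version; the placement of p_undrop is an assumption here, so consult the kernel body for the authoritative arithmetic:

#include <cstddef>

// Scalar reference for the dO.O row reduction: d[i] = sum_j do[i][j] * o[i][j].
// Whether and where p_undrop rescales the result is an assumption; see the
// device implementation at fmha_bwd_kernel.hpp:1334 for the real version.
inline void ograd_dot_o(const float* o, const float* do_, float* d,
                        std::size_t seqlen_q, std::size_t hdim_v,
                        std::size_t stride_o, std::size_t stride_do, float p_undrop)
{
    for(std::size_t i = 0; i < seqlen_q; ++i)
    {
        float acc = 0.0f;
        for(std::size_t j = 0; j < hdim_v; ++j)
            acc += do_[i * stride_do + j] * o[i * stride_o + j];
        d[i] = acc * p_undrop;
    }
}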
Referenced definitions in other headers: integral_constant.hpp:13, block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:772, coordinate_transform.hpp:1392, sequence.hpp:49.