include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp Source File

include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp Source File#

Composable Kernel: include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp Source File
fmha_bwd_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
9 
10 #include <string>
11 #include <type_traits>
12 #include <utility>
13 #include <variant>
14 
15 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
16 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
17 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
18 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
19 // dV[seqlen_k, hdim_v] = P^T[seqlen_k, seqlen_q] @ dO^T[hdim_v, seqlen_q]
20 // dP[seqlen_q, seqlen_k] = dO[seqlen_q, hdim_v] @ V[seqlen_k, hdim_v]
21 // D[seqlen_q] = rowsum(dO[seqlen_q, hdim_v] * O[seqlen_q, hdim_v])
22 // dS''[seqlen_q, seqlen_k] = P[seqlen_q, seqlen_k] * (dP[seqlen_q, seqlen_k] - D[seqlen_q])
23 // dBias[seqlen_q, seqlen_k] = dS'[seqlen_q, seqlen_k] = dS''[seqlen_q, seqlen_k]
24 // dK[seqlen_k, hdim_q] = dS'^T[seqlen_k, seqlen_q] @ Q^T[hdim_q, seqlen_q] * Scale[1]
25 // dQ[seqlen_q, hdim_q] = dS'[seqlen_q, seqlen_k] @ K^T[hdim_q, seqlen_k] * Scale[1]
26 
27 namespace ck_tile {
28 
29 template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
31 {
35  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
36  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
37 
53 
54  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
55  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
56  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
57  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
58  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
59  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
60  static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
63  static constexpr bool kHasMask = FmhaMask::IsMasking;
64  static constexpr bool kHasDropout = FmhaDropout::IsDropout;
65  static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
66  static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
67 
    // clang-format off
    // Compile-time type -> short-string mapping; used by GetName() to embed
    // the Q data type ("fp16"/"bf16") into the generated kernel-instance name.
    template <typename T> struct t2s;
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    // clang-format on
73 
    // Build the human-readable kernel-instance name encoding tile shape, data
    // type, mode (batch/group), warp layout, and every enabled feature flag.
    // The format is consumed by codegen tooling and must stay in sync with
    // generate.py.
    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        using bfs = typename FmhaPipeline::BlockFmhaShape;
        using gbr0 = typename bfs::Gemm0BlockWarps;
        using gbr1 = typename bfs::Gemm1BlockWarps;
        using gbr4 = typename bfs::Gemm4BlockWarps;
        using gwt0 = typename bfs::Gemm0WarpTile;
        using gwt1 = typename bfs::Gemm1WarpTile;
        #define _SS_ std::string
        #define _TS_ std::to_string
        // "p" + one letter per padded dimension, or empty when nothing is padded
        auto pn = [&] () {
        std::string n;
        if (kPadSeqLenQ) n += "s";
        if (kPadSeqLenK) n += "sk";
        if (kPadHeadDimQ) n += "d";
        if (kPadHeadDimV) n += "dv";
        return n.empty() ? n : std::string("p") + n; }();
        return
        _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
        "_" + (kIsGroupMode ? "group" : "batch") + "_" +
        "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" +
        _TS_(bfs::kK4) + "x" + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
        "r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" +
        "r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" +
        "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
        "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
        "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
        ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
        // NOTE(review): a name segment between the padding flags and the
        // "_dbias" part appears to have been lost in extraction (likely the
        // bias-kind suffix) -- TODO confirm against the upstream file.
        (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) +
        (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "" );
        #undef _SS_
        #undef _TS_
        // clang-format on
    }
111 
112  template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
113  // arg
115  {
116  };
117 
118  // kargs use aggregate initializer, so no constructor will provided
119  // use inheritance to minimize karg size
120  // user need to use MakeKargs() function to create kargs.
122  {
123  const void* q_ptr;
124  const void* k_ptr;
125  const void* v_ptr;
126  const void* lse_ptr;
127  const void* do_ptr;
128  const void* d_ptr;
129  void* dq_acc_ptr;
130  void* dk_ptr;
131  void* dv_ptr;
132 
137 
138  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
139  // if this param is larger than 1, indicate MQA/GQA case
142  float raw_scale;
143  float scale;
144 
152 
161  };
162 
164  {
165  const void* bias_ptr = nullptr;
168  };
169 
171  {
173  };
174 
176  {
177  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
178  const void* alibi_slope_ptr;
179  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
180  };
181 
183  {
184  void* dbias_ptr = nullptr;
187  };
188 
190  {
192  };
193 
195  {
198  };
199 
201  {
202  template <typename T>
204  {
205  T val;
206  const T* ptr;
207  };
208 
212  };
213 
215  {
216  void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
217  {
218  float p_undrop = 1.0 - p_drop;
220  uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
221  rp_undrop = 1.0 / p_undrop;
222  scale_rp_undrop = rp_undrop * raw_scale;
223 
224  this->drop_seed.val = seed;
225  this->drop_offset.val = offset;
226  this->is_drop_seed_offset_from_host = true;
227  }
228 
229  void init_dropout(float p_drop,
230  const uint64_t* seed_ptr,
231  const uint64_t* offset_ptr,
232  float raw_scale)
233  {
234  float p_undrop = 1.0 - p_drop;
236  uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
237  rp_undrop = 1.0 / p_undrop;
238  scale_rp_undrop = rp_undrop * raw_scale;
239 
240  this->drop_seed.ptr = seed_ptr;
241  this->drop_offset.ptr = offset_ptr;
242  this->is_drop_seed_offset_from_host = false;
243  }
244 
245  float rp_undrop = 1;
246  float scale_rp_undrop = 1;
248  void* rand_val_ptr = nullptr;
249 
252  };
253 
255  {
257  };
258 
260  {
262  };
263 
266  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
267  FmhaBwdBatchModeBiasKargs,
268  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
269  FmhaBwdAlibiKargs,
270  FmhaBwdEmptyKargs<0>>>,
271  std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
272  std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
273  std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>,
274  std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
275  {
284  };
285 
288  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
289  FmhaBwdCommonBiasKargs,
290  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
291  FmhaBwdAlibiKargs,
292  FmhaBwdEmptyKargs<0>>>,
293  std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
294  std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
295  std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
296  std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
297  {
298  const int32_t* seqstart_q_ptr;
299  const int32_t* seqstart_k_ptr;
300  const int32_t* seqlen_k_ptr;
301  };
302 
    // Compile-time selection of the argument bundle: group mode carries
    // seqstart/seqlen pointers (variable lengths), batch mode carries
    // per-batch strides (fixed lengths).
    using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
304 
305  template <bool Cond = !kIsGroupMode>
306  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
307  MakeKargsImpl(const void* q_ptr,
308  const void* k_ptr,
309  const void* v_ptr,
310  const void* bias_ptr,
311  const void* lse_ptr,
312  const void* do_ptr,
313  const void* d_ptr,
314  void* rand_val_ptr,
315  void* dk_ptr,
316  void* dv_ptr,
317  void* dbias_ptr,
318  void* dq_acc_ptr,
319  ck_tile::index_t seqlen_q,
320  ck_tile::index_t seqlen_k,
321  ck_tile::index_t hdim_q,
322  ck_tile::index_t hdim_v,
323  ck_tile::index_t num_head_q,
324  ck_tile::index_t nhead_ratio_qk,
325  float scale,
326  ck_tile::index_t stride_q,
327  ck_tile::index_t stride_k,
328  ck_tile::index_t stride_v,
329  ck_tile::index_t stride_bias,
330  ck_tile::index_t stride_randval,
331  ck_tile::index_t stride_do,
332  ck_tile::index_t stride_dq_acc,
333  ck_tile::index_t stride_dk,
334  ck_tile::index_t stride_dv,
335  ck_tile::index_t stride_dbias,
336  ck_tile::index_t nhead_stride_q,
337  ck_tile::index_t nhead_stride_k,
338  ck_tile::index_t nhead_stride_v,
339  ck_tile::index_t nhead_stride_bias,
340  ck_tile::index_t nhead_stride_randval,
341  ck_tile::index_t nhead_stride_do,
342  ck_tile::index_t nhead_stride_lsed,
343  ck_tile::index_t nhead_stride_dq_acc,
344  ck_tile::index_t nhead_stride_dk,
345  ck_tile::index_t nhead_stride_dv,
346  ck_tile::index_t nhead_stride_dbias,
347  ck_tile::index_t batch_stride_q,
348  ck_tile::index_t batch_stride_k,
349  ck_tile::index_t batch_stride_v,
350  ck_tile::index_t batch_stride_bias,
351  ck_tile::index_t batch_stride_randval,
352  ck_tile::index_t batch_stride_do,
353  ck_tile::index_t batch_stride_lsed,
354  ck_tile::index_t batch_stride_dq_acc,
355  ck_tile::index_t batch_stride_dk,
356  ck_tile::index_t batch_stride_dv,
357  ck_tile::index_t batch_stride_dbias,
358  ck_tile::index_t split_stride_dq_acc,
359  ck_tile::index_t window_size_left,
360  ck_tile::index_t window_size_right,
361  ck_tile::index_t mask_type,
362  float p_drop,
363  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
364  drop_seed_offset)
365  {
366  Kargs kargs{{q_ptr,
367  k_ptr,
368  v_ptr,
369  lse_ptr,
370  do_ptr,
371  d_ptr,
372  dq_acc_ptr,
373  dk_ptr,
374  dv_ptr,
375  seqlen_q,
376  seqlen_k,
377  hdim_q,
378  hdim_v,
379  num_head_q,
380  nhead_ratio_qk,
381  scale,
382  static_cast<float>(scale * ck_tile::log2e_v<>),
383  stride_q,
384  stride_k,
385  stride_v,
386  stride_do,
387  stride_dq_acc,
388  stride_dk,
389  stride_dv,
390  nhead_stride_q,
391  nhead_stride_k,
392  nhead_stride_v,
393  nhead_stride_do,
394  nhead_stride_lsed,
395  nhead_stride_dq_acc,
396  nhead_stride_dk,
397  nhead_stride_dv}, // args for common karg
398  {}, // placeholder for bias
399  {}, // placeholder for dbias
400  {}, // placeholder for mask
401  {}, // placeholder for dropout
402  {}, // placeholder for deterministic
403  batch_stride_q,
404  batch_stride_k,
405  batch_stride_v,
406  batch_stride_do,
407  batch_stride_lsed,
408  batch_stride_dq_acc,
409  batch_stride_dk,
410  batch_stride_dv};
411 
413  {
414  kargs.bias_ptr = bias_ptr;
415  kargs.stride_bias = stride_bias;
416  kargs.nhead_stride_bias = nhead_stride_bias;
417  kargs.batch_stride_bias = batch_stride_bias;
418  }
419  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
420  {
421  kargs.alibi_slope_ptr = bias_ptr;
422  kargs.alibi_slope_stride = stride_bias;
423  }
424 
425  if constexpr(kHasBiasGrad)
426  {
427  kargs.dbias_ptr = dbias_ptr;
428  kargs.stride_dbias = stride_dbias;
429  kargs.nhead_stride_dbias = nhead_stride_dbias;
430  kargs.batch_stride_dbias = batch_stride_dbias;
431  }
432 
433  if constexpr(kHasMask)
434  {
435  kargs.window_size_left = window_size_left;
436  kargs.window_size_right = window_size_right;
437  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
438  }
439 
440  if constexpr(kHasDropout)
441  {
442  if(drop_seed_offset.index() == 0) // seed & offset come from host
443  {
444  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
445  kargs.init_dropout(p_drop, seed, offset, scale);
446  }
447  else // seed & offset come from device
448  {
449  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
450  kargs.init_dropout(p_drop,
451  reinterpret_cast<const uint64_t*>(seed_ptr),
452  reinterpret_cast<const uint64_t*>(offset_ptr),
453  scale);
454  }
455 
456  if constexpr(kIsStoreRandval)
457  {
458  kargs.rand_val_ptr = rand_val_ptr;
459  kargs.stride_randval = stride_randval;
460  kargs.nhead_stride_randval = nhead_stride_randval;
461  kargs.batch_stride_randval = batch_stride_randval;
462  }
463  }
464 
465  if constexpr(kIsDeterministic)
466  {
467  kargs.split_stride_dq_acc = split_stride_dq_acc;
468  }
469 
470  return kargs;
471  }
472 
    // std::variant<> can't take in a list initializer, overload for backward compatibility
    // Batch-mode overload accepting host-resident (seed, offset) dropout
    // values; wraps them into the variant expected by MakeKargsImpl.
    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              const void* lse_ptr,
              const void* do_ptr,
              const void* d_ptr,
              void* rand_val_ptr,
              void* dk_ptr,
              void* dv_ptr,
              void* dbias_ptr,
              void* dq_acc_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              float scale,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_do,
              ck_tile::index_t stride_dq_acc,
              ck_tile::index_t stride_dk,
              ck_tile::index_t stride_dv,
              ck_tile::index_t stride_dbias,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_lsed,
              ck_tile::index_t nhead_stride_dq_acc,
              ck_tile::index_t nhead_stride_dk,
              ck_tile::index_t nhead_stride_dv,
              ck_tile::index_t nhead_stride_dbias,
              ck_tile::index_t batch_stride_q,
              ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v,
              ck_tile::index_t batch_stride_bias,
              ck_tile::index_t batch_stride_randval,
              ck_tile::index_t batch_stride_do,
              ck_tile::index_t batch_stride_lsed,
              ck_tile::index_t batch_stride_dq_acc,
              ck_tile::index_t batch_stride_dk,
              ck_tile::index_t batch_stride_dv,
              ck_tile::index_t batch_stride_dbias,
              ck_tile::index_t split_stride_dq_acc,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        // Positional forwarding: argument order must match MakeKargsImpl exactly.
        return MakeKargsImpl(
            q_ptr,
            k_ptr,
            v_ptr,
            bias_ptr,
            lse_ptr,
            do_ptr,
            d_ptr,
            rand_val_ptr,
            dk_ptr,
            dv_ptr,
            dbias_ptr,
            dq_acc_ptr,
            seqlen_q,
            seqlen_k,
            hdim_q,
            hdim_v,
            num_head_q,
            nhead_ratio_qk,
            scale,
            stride_q,
            stride_k,
            stride_v,
            stride_bias,
            stride_randval,
            stride_do,
            stride_dq_acc,
            stride_dk,
            stride_dv,
            stride_dbias,
            nhead_stride_q,
            nhead_stride_k,
            nhead_stride_v,
            nhead_stride_bias,
            nhead_stride_randval,
            nhead_stride_do,
            nhead_stride_lsed,
            nhead_stride_dq_acc,
            nhead_stride_dk,
            nhead_stride_dv,
            nhead_stride_dbias,
            batch_stride_q,
            batch_stride_k,
            batch_stride_v,
            batch_stride_bias,
            batch_stride_randval,
            batch_stride_do,
            batch_stride_lsed,
            batch_stride_dq_acc,
            batch_stride_dk,
            batch_stride_dv,
            batch_stride_dbias,
            split_stride_dq_acc,
            window_size_left,
            window_size_right,
            mask_type,
            p_drop,
            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
    }
593 
    // std::variant<> can't take in a list initializer, overload for backward compatibility
    // Batch-mode overload accepting device pointers to the dropout seed and
    // offset; wraps them into the variant expected by MakeKargsImpl.
    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              const void* lse_ptr,
              const void* do_ptr,
              const void* d_ptr,
              void* rand_val_ptr,
              void* dk_ptr,
              void* dv_ptr,
              void* dbias_ptr,
              void* dq_acc_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              float scale,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_do,
              ck_tile::index_t stride_dq_acc,
              ck_tile::index_t stride_dk,
              ck_tile::index_t stride_dv,
              ck_tile::index_t stride_dbias,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_lsed,
              ck_tile::index_t nhead_stride_dq_acc,
              ck_tile::index_t nhead_stride_dk,
              ck_tile::index_t nhead_stride_dv,
              ck_tile::index_t nhead_stride_dbias,
              ck_tile::index_t batch_stride_q,
              ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v,
              ck_tile::index_t batch_stride_bias,
              ck_tile::index_t batch_stride_randval,
              ck_tile::index_t batch_stride_do,
              ck_tile::index_t batch_stride_lsed,
              ck_tile::index_t batch_stride_dq_acc,
              ck_tile::index_t batch_stride_dk,
              ck_tile::index_t batch_stride_dv,
              ck_tile::index_t batch_stride_dbias,
              ck_tile::index_t split_stride_dq_acc,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
              const std::tuple<const void*, const void*>& drop_seed_offset)
    {
        // Positional forwarding: argument order must match MakeKargsImpl exactly.
        return MakeKargsImpl(
            q_ptr,
            k_ptr,
            v_ptr,
            bias_ptr,
            lse_ptr,
            do_ptr,
            d_ptr,
            rand_val_ptr,
            dk_ptr,
            dv_ptr,
            dbias_ptr,
            dq_acc_ptr,
            seqlen_q,
            seqlen_k,
            hdim_q,
            hdim_v,
            num_head_q,
            nhead_ratio_qk,
            scale,
            stride_q,
            stride_k,
            stride_v,
            stride_bias,
            stride_randval,
            stride_do,
            stride_dq_acc,
            stride_dk,
            stride_dv,
            stride_dbias,
            nhead_stride_q,
            nhead_stride_k,
            nhead_stride_v,
            nhead_stride_bias,
            nhead_stride_randval,
            nhead_stride_do,
            nhead_stride_lsed,
            nhead_stride_dq_acc,
            nhead_stride_dk,
            nhead_stride_dv,
            nhead_stride_dbias,
            batch_stride_q,
            batch_stride_k,
            batch_stride_v,
            batch_stride_bias,
            batch_stride_randval,
            batch_stride_do,
            batch_stride_lsed,
            batch_stride_dq_acc,
            batch_stride_dk,
            batch_stride_dv,
            batch_stride_dbias,
            split_stride_dq_acc,
            window_size_left,
            window_size_right,
            mask_type,
            p_drop,
            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
    }
714 
715  template <bool Cond = kIsGroupMode>
716  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
717  MakeKargsImpl(const void* q_ptr,
718  const void* k_ptr,
719  const void* v_ptr,
720  const void* bias_ptr,
721  const void* lse_ptr,
722  const void* do_ptr,
723  const void* d_ptr,
724  void* rand_val_ptr,
725  void* dk_ptr,
726  void* dv_ptr,
727  void* dbias_ptr,
728  void* dq_acc_ptr,
729  const void* seqstart_q_ptr,
730  const void* seqstart_k_ptr,
731  const void* seqlen_k_ptr,
732  ck_tile::index_t hdim_q,
733  ck_tile::index_t hdim_v,
734  ck_tile::index_t num_head_q,
735  ck_tile::index_t nhead_ratio_qk,
736  float scale,
737  ck_tile::index_t stride_q,
738  ck_tile::index_t stride_k,
739  ck_tile::index_t stride_v,
740  ck_tile::index_t stride_bias,
741  ck_tile::index_t stride_randval,
742  ck_tile::index_t stride_do,
743  ck_tile::index_t stride_dq_acc,
744  ck_tile::index_t stride_dk,
745  ck_tile::index_t stride_dv,
746  ck_tile::index_t stride_dbias,
747  ck_tile::index_t nhead_stride_q,
748  ck_tile::index_t nhead_stride_k,
749  ck_tile::index_t nhead_stride_v,
750  ck_tile::index_t nhead_stride_bias,
751  ck_tile::index_t nhead_stride_randval,
752  ck_tile::index_t nhead_stride_do,
753  ck_tile::index_t nhead_stride_lsed,
754  ck_tile::index_t nhead_stride_dq_acc,
755  ck_tile::index_t nhead_stride_dk,
756  ck_tile::index_t nhead_stride_dv,
757  ck_tile::index_t nhead_stride_dbias,
758  ck_tile::index_t split_stride_dq_acc,
759  ck_tile::index_t window_size_left,
760  ck_tile::index_t window_size_right,
761  ck_tile::index_t mask_type,
762  float p_drop,
763  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
764  drop_seed_offset)
765  {
766  Kargs kargs{{q_ptr,
767  k_ptr,
768  v_ptr,
769  lse_ptr,
770  do_ptr,
771  d_ptr,
772  dq_acc_ptr,
773  dk_ptr,
774  dv_ptr,
775  -1, // seqlen will be updated by another pointer
776  -1, //
777  hdim_q,
778  hdim_v,
779  num_head_q,
780  nhead_ratio_qk,
781  scale,
782  static_cast<float>(scale * ck_tile::log2e_v<>),
783  stride_q,
784  stride_k,
785  stride_v,
786  stride_do,
787  stride_dq_acc,
788  stride_dk,
789  stride_dv,
790  nhead_stride_q,
791  nhead_stride_k,
792  nhead_stride_v,
793  nhead_stride_do,
794  nhead_stride_lsed,
795  nhead_stride_dq_acc,
796  nhead_stride_dk,
797  nhead_stride_dv}, // args for common karg
798  {}, // placeholder for bias
799  {}, // placeholder for dbias
800  {}, // placeholder for mask
801  {}, // placeholder for dropout
802  {}, // placeholder for deterministic
803  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
804  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
805  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
806 
808  {
809  kargs.bias_ptr = bias_ptr;
810  kargs.stride_bias = stride_bias;
811  kargs.nhead_stride_bias = nhead_stride_bias;
812  }
813  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
814  {
815  kargs.alibi_slope_ptr = bias_ptr;
816  kargs.alibi_slope_stride = stride_bias;
817  }
818  if constexpr(kHasBiasGrad)
819  {
820  kargs.dbias_ptr = dbias_ptr;
821  kargs.stride_dbias = stride_dbias;
822  kargs.nhead_stride_dbias = nhead_stride_dbias;
823  }
824  if constexpr(kHasMask)
825  {
826  kargs.window_size_left = window_size_left;
827  kargs.window_size_right = window_size_right;
828  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
829  }
830  if constexpr(kHasDropout)
831  {
832  if(drop_seed_offset.index() == 0) // seed & offset come from host
833  {
834  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
835  kargs.init_dropout(p_drop, seed, offset, scale);
836  }
837  else // seed & offset come from device
838  {
839  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
840  kargs.init_dropout(p_drop,
841  reinterpret_cast<const uint64_t*>(seed_ptr),
842  reinterpret_cast<const uint64_t*>(offset_ptr),
843  scale);
844  }
845 
846  if constexpr(kIsStoreRandval)
847  {
848  kargs.rand_val_ptr = rand_val_ptr;
849  kargs.stride_randval = stride_randval;
850  kargs.nhead_stride_randval = nhead_stride_randval;
851  }
852  }
853  if constexpr(kIsDeterministic)
854  {
855  kargs.split_stride_dq_acc = split_stride_dq_acc;
856  }
857 
858  return kargs;
859  }
860 
    // std::variant<> can't take in a list initializer, overload for backward compatibility
    // Group-mode overload accepting host-resident (seed, offset) dropout
    // values; wraps them into the variant expected by MakeKargsImpl.
    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              const void* lse_ptr,
              const void* do_ptr,
              const void* d_ptr,
              void* rand_val_ptr,
              void* dk_ptr,
              void* dv_ptr,
              void* dbias_ptr,
              void* dq_acc_ptr,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              const void* seqlen_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              float scale,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_do,
              ck_tile::index_t stride_dq_acc,
              ck_tile::index_t stride_dk,
              ck_tile::index_t stride_dv,
              ck_tile::index_t stride_dbias,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_lsed,
              ck_tile::index_t nhead_stride_dq_acc,
              ck_tile::index_t nhead_stride_dk,
              ck_tile::index_t nhead_stride_dv,
              ck_tile::index_t nhead_stride_dbias,
              ck_tile::index_t split_stride_dq_acc,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        // Positional forwarding: argument order must match MakeKargsImpl exactly.
        return MakeKargsImpl(
            q_ptr,
            k_ptr,
            v_ptr,
            bias_ptr,
            lse_ptr,
            do_ptr,
            d_ptr,
            rand_val_ptr,
            dk_ptr,
            dv_ptr,
            dbias_ptr,
            dq_acc_ptr,
            seqstart_q_ptr,
            seqstart_k_ptr,
            seqlen_k_ptr,
            hdim_q,
            hdim_v,
            num_head_q,
            nhead_ratio_qk,
            scale,
            stride_q,
            stride_k,
            stride_v,
            stride_bias,
            stride_randval,
            stride_do,
            stride_dq_acc,
            stride_dk,
            stride_dv,
            stride_dbias,
            nhead_stride_q,
            nhead_stride_k,
            nhead_stride_v,
            nhead_stride_bias,
            nhead_stride_randval,
            nhead_stride_do,
            nhead_stride_lsed,
            nhead_stride_dq_acc,
            nhead_stride_dk,
            nhead_stride_dv,
            nhead_stride_dbias,
            split_stride_dq_acc,
            window_size_left,
            window_size_right,
            mask_type,
            p_drop,
            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
    }
961 
    // std::variant<> can't take in a list initializer, overload for backward compatibility
    // Group-mode overload accepting device pointers to the dropout seed and
    // offset; wraps them into the variant expected by MakeKargsImpl.
    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              const void* lse_ptr,
              const void* do_ptr,
              const void* d_ptr,
              void* rand_val_ptr,
              void* dk_ptr,
              void* dv_ptr,
              void* dbias_ptr,
              void* dq_acc_ptr,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              const void* seqlen_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              float scale,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_do,
              ck_tile::index_t stride_dq_acc,
              ck_tile::index_t stride_dk,
              ck_tile::index_t stride_dv,
              ck_tile::index_t stride_dbias,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_lsed,
              ck_tile::index_t nhead_stride_dq_acc,
              ck_tile::index_t nhead_stride_dk,
              ck_tile::index_t nhead_stride_dv,
              ck_tile::index_t nhead_stride_dbias,
              ck_tile::index_t split_stride_dq_acc,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
              const std::tuple<const void*, const void*>& drop_seed_offset)
    {
        // Positional forwarding: argument order must match MakeKargsImpl exactly.
        return MakeKargsImpl(
            q_ptr,
            k_ptr,
            v_ptr,
            bias_ptr,
            lse_ptr,
            do_ptr,
            d_ptr,
            rand_val_ptr,
            dk_ptr,
            dv_ptr,
            dbias_ptr,
            dq_acc_ptr,
            seqstart_q_ptr,
            seqstart_k_ptr,
            seqlen_k_ptr,
            hdim_q,
            hdim_v,
            num_head_q,
            nhead_ratio_qk,
            scale,
            stride_q,
            stride_k,
            stride_v,
            stride_bias,
            stride_randval,
            stride_do,
            stride_dq_acc,
            stride_dk,
            stride_dv,
            stride_dbias,
            nhead_stride_q,
            nhead_stride_k,
            nhead_stride_v,
            nhead_stride_bias,
            nhead_stride_randval,
            nhead_stride_do,
            nhead_stride_lsed,
            nhead_stride_dq_acc,
            nhead_stride_dk,
            nhead_stride_dv,
            nhead_stride_dbias,
            split_stride_dq_acc,
            window_size_left,
            window_size_right,
            mask_type,
            p_drop,
            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
    }
1062 
1063  CK_TILE_HOST static constexpr auto
1065  {
1066  return dim3(
1067  ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_);
1068  }
1069 
1070  CK_TILE_DEVICE static constexpr auto GetTileIndex()
1071  {
1072  const index_t i_block = blockIdx.x;
1073  const index_t i_nhead = blockIdx.y;
1074  const index_t i_batch = blockIdx.z;
1075 
1076  return ck_tile::make_tuple(i_block, i_nhead, i_batch);
1077  }
1078 
1079  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
1080 
1082  {
1083  return ck_tile::max(FmhaPipeline::GetSmemSize(),
1084  KGradEpiloguePipeline::GetSmemSize(),
1085  VGradEpiloguePipeline::GetSmemSize());
1086  }
1087 
1088  CK_TILE_DEVICE void operator()(Kargs kargs) const
1089  {
1090  // allocate LDS
1091  __shared__ char smem_ptr[GetSmemSize()];
1092 
1093  // divide problem
1094  const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
1095 
1096  const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0);
1097 
1098  long_index_t batch_offset_q = 0;
1099  long_index_t batch_offset_k = 0;
1100  long_index_t batch_offset_v = 0;
1101  long_index_t batch_offset_bias = 0;
1102  long_index_t batch_offset_randval = 0;
1103  long_index_t batch_offset_do = 0;
1104  long_index_t batch_offset_lsed = 0;
1105  long_index_t batch_offset_dq_acc = 0;
1106  long_index_t batch_offset_dk = 0;
1107  long_index_t batch_offset_dv = 0;
1108  long_index_t batch_offset_dbias = 0;
1109 
1110  if constexpr(kIsGroupMode)
1111  {
1112  // get starting offset for each batch
1113  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1114  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1115 
1116  batch_offset_q = query_start * kargs.stride_q;
1117  batch_offset_k = key_start * kargs.stride_k;
1118  batch_offset_v = key_start * kargs.stride_v;
1119  batch_offset_do = query_start * kargs.stride_do;
1120  batch_offset_lsed = query_start;
1121  batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
1122  batch_offset_dk = key_start * kargs.stride_dk;
1123  batch_offset_dv = key_start * kargs.stride_dv;
1125  {
1126  batch_offset_bias = query_start * kargs.stride_bias;
1127  }
1128  if constexpr(kHasBiasGrad)
1129  {
1130  batch_offset_dbias = query_start * kargs.stride_dbias;
1131  }
1132  else
1133  {
1134  batch_offset_dbias = key_start;
1135  }
1136  if constexpr(kIsStoreRandval)
1137  {
1138  batch_offset_randval = query_start * kargs.stride_randval;
1139  }
1140 
1141  // get real # queries & # keys under group mode
1142  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1143  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1144  if(kargs.seqlen_k_ptr != nullptr)
1145  {
1146  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1147  }
1148  else
1149  {
1150  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1151  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1152  }
1153 
1154  // # of required blocks is different in each groups, terminate unnecessary blocks
1155  // earlier
1156  if(kargs.seqlen_k <= i_n0)
1157  {
1158  return;
1159  }
1160  }
1161  else
1162  {
1163  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1164  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1165  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1166  batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
1167  batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
1168  batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
1169  batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
1170  batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
1172  {
1173  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1174  }
1175  if constexpr(kHasBiasGrad)
1176  {
1177  batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
1178  }
1179  if constexpr(kIsStoreRandval)
1180  {
1181  batch_offset_randval =
1182  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
1183  }
1184  }
1185 
1186  // for simplicity, batch stride we just modify the pointer
1187  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1188  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1189  batch_offset_q;
1190  const KDataType* k_ptr =
1191  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1192  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1193  batch_offset_k;
1194  const VDataType* v_ptr =
1195  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1196  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1197  batch_offset_v;
1198  const LSEDataType* lse_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_ptr) +
1199  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
1200  batch_offset_lsed;
1201  const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
1202  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
1203  batch_offset_lsed;
1204  const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
1205  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
1206  batch_offset_do;
1207  KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
1208  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk +
1209  batch_offset_dk;
1210  VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
1211  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv +
1212  batch_offset_dv;
1213 
1214  // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
1215  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1216  q_ptr,
1217  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1218  make_tuple(kargs.stride_q, 1),
1220  number<1>{});
1221  const auto q_dram = pad_tensor_view(
1222  q_dram_naive,
1225 
1226  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1227  k_ptr,
1228  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1229  make_tuple(kargs.stride_k, 1),
1231  number<1>{});
1232  const auto k_dram = pad_tensor_view(
1233  k_dram_naive,
1236 
1237  const auto v_dram = [&]() {
1238  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1239  v_ptr,
1240  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1241  make_tuple(kargs.stride_v, 1),
1243  number<1>{});
1244  return pad_tensor_view(
1245  v_dram_naive,
1248  }();
1249 
1250  const auto lse_dram = [&]() {
1251  const auto lse_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
1252  lse_ptr, make_tuple(kargs.seqlen_q), number<1>{});
1253  return pad_tensor_view(
1255  }();
1256 
1257  const auto d_dram = [&]() {
1258  const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
1259  d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
1260  return pad_tensor_view(
1262  }();
1263 
1264  const auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1265  do_ptr,
1266  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1267  make_tuple(kargs.stride_do, 1),
1269  number<1>{});
1270  const auto do_dram = pad_tensor_view(
1271  do_dram_naive,
1274 
1275  auto q_dram_window = make_tile_window(
1276  q_dram,
1278  {0, 0});
1279 
1280  auto k_dram_window = make_tile_window(
1281  k_dram,
1283  {i_n0, 0});
1284 
1285  auto v_dram_window = make_tile_window(
1286  v_dram,
1288  {i_n0, 0});
1289 
1290  auto do_dram_window = make_tile_window(
1291  do_dram,
1293  {0, 0});
1294 
1295  auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
1296  if constexpr(kIsDeterministic)
1297  {
1298  AccDataType* dq_acc_ptr =
1299  reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
1300  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
1301  static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
1302  batch_offset_dq_acc;
1303 
1304  auto dq_acc_dram = [&]() {
1305  const auto dq_acc_dram_naive =
1306  make_naive_tensor_view<address_space_enum::global>(
1307  dq_acc_ptr,
1308  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1309  make_tuple(kargs.stride_dq_acc, 1),
1311  number<1>{});
1312 
1313  return pad_tensor_view(
1314  dq_acc_dram_naive,
1317  }();
1318 
1319  return make_tile_window(
1320  dq_acc_dram,
1322  {0, 0});
1323  }
1324  else
1325  {
1326  AccDataType* dq_acc_ptr =
1327  reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
1328  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
1329  batch_offset_dq_acc;
1330 
1331  auto dq_acc_dram = [&]() {
1332  const auto dq_acc_dram_naive =
1335  dq_acc_ptr,
1336  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1337  make_tuple(kargs.stride_dq_acc, 1),
1339  number<1>{});
1340 
1341  return pad_tensor_view(
1342  dq_acc_dram_naive,
1345  }();
1346 
1347  return make_tile_window(
1348  dq_acc_dram,
1350  {0, 0});
1351  }
1352  }();
1353 
1354  auto lse_dram_window =
1356 
1357  auto d_dram_window = make_tile_window(d_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
1358 
1361  constexpr auto bias_dram_window_lengths =
1363  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1365  {
1366  const BiasDataType* bias_ptr =
1367  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1368  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1369  batch_offset_bias;
1370 
1371  const auto bias_dram = [&]() {
1372  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1373  bias_ptr,
1374  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1375  make_tuple(kargs.stride_bias, 1),
1377  number<1>{});
1378 
1379  return pad_tensor_view(bias_dram_naive,
1380  bias_dram_window_lengths,
1382  }();
1383 
1384  return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0});
1385  }
1386  else
1387  {
1388  return make_null_tile_window(bias_dram_window_lengths);
1389  }
1390  }();
1391 
1392  auto dbias_dram_window = [&, i_nhead_ = i_nhead]() {
1393  if constexpr(kHasBiasGrad)
1394  {
1395  BiasGradDataType* dbias_ptr =
1396  reinterpret_cast<BiasGradDataType*>(kargs.dbias_ptr) +
1397  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dbias +
1398  batch_offset_dbias;
1399 
1400  auto dbias_dram = [&]() {
1401  const auto dbias_dram_naive =
1402  make_naive_tensor_view<address_space_enum::global>(
1403  dbias_ptr,
1404  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1405  make_tuple(kargs.stride_dbias, 1),
1407  number<1>{});
1408 
1409  return pad_tensor_view(dbias_dram_naive,
1410  bias_dram_window_lengths,
1412  }();
1413 
1414  return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0});
1415  }
1416  else
1417  {
1418  return make_null_tile_window(bias_dram_window_lengths);
1419  }
1420  }();
1421 
1422  // WA i_batch capture structure binding before c++20
1423  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1424  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1425  {
1426  // data loading, shared by entire wg
1427  // TODO: how to use s_read?
1428  AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
1429  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1430  slope *= ck_tile::log2e_v<>;
1431  if constexpr(kHasMask)
1432  {
1433  return make_alibi_from_lr_mask<AccDataType, false>(slope,
1434  kargs.window_size_left,
1435  kargs.window_size_right,
1436  kargs.seqlen_q,
1437  kargs.seqlen_k,
1438  kargs.mask_type);
1439  }
1440  else
1441  {
1443  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1444  }
1445  }
1446  else
1447  {
1449  }
1450  }();
1451 
1452  // dropout
1453  float rp_undrop = 1;
1454  float scale_rp_undrop = 1;
1455  if constexpr(kHasDropout)
1456  {
1457  rp_undrop = kargs.rp_undrop;
1458  scale_rp_undrop = kargs.scale_rp_undrop;
1459  }
1460  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1461  if constexpr(kHasDropout)
1462  {
1463  return FmhaDropout{i_batch_,
1464  i_nhead_,
1465  kargs.num_head_q,
1466  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1467  : *kargs.drop_seed.ptr,
1468  kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
1469  : *kargs.drop_offset.ptr,
1470  kargs.rp_undrop,
1471  kargs.p_undrop_in_uint8_t};
1472  }
1473  else
1474  {
1475  return FmhaDropout{};
1476  };
1477  }();
1478 
1479  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1480  constexpr auto randval_dram_window_lengths =
1482  if constexpr(kIsStoreRandval)
1483  {
1484  RandValOutputDataType* rand_val_ptr =
1485  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1486  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1487  batch_offset_randval;
1488 
1489  const auto randval_dram = [&]() {
1490  const auto randval_dram_naive =
1491  make_naive_tensor_view<address_space_enum::global>(
1492  rand_val_ptr,
1493  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1494  make_tuple(kargs.stride_randval, 1),
1495  number<1>{},
1496  number<1>{});
1497 
1498  return pad_tensor_view(randval_dram_naive,
1499  randval_dram_window_lengths,
1501  }();
1502 
1503  return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0});
1504  }
1505  else
1506  {
1507  return make_null_tile_window(randval_dram_window_lengths);
1508  }
1509  }();
1510 
1511  FmhaMask mask = [&]() {
1512  if constexpr(kHasMask)
1513  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1514  kargs.window_size_left,
1515  kargs.window_size_right,
1516  kargs.seqlen_q,
1517  kargs.seqlen_k,
1519  else
1520  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1521  }();
1522 
1523  auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
1524  k_dram_window,
1525  v_dram_window,
1526  bias_dram_window,
1527  randval_dram_window,
1528  do_dram_window,
1529  lse_dram_window,
1530  d_dram_window,
1531  dq_dram_window,
1532  dbias_dram_window,
1533  mask,
1534  position_encoding,
1535  kargs.raw_scale,
1536  kargs.scale,
1537  rp_undrop,
1538  scale_rp_undrop,
1539  smem_ptr,
1540  dropout);
1541 
1542  auto dk_dram = [&]() {
1543  const auto dk_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1544  dk_ptr,
1545  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1546  make_tuple(kargs.stride_dk, 1),
1548  number<1>{});
1549 
1550  return pad_tensor_view(
1551  dk_dram_naive,
1554  }();
1555 
1556  auto dv_dram = [&]() {
1557  const auto dv_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1558  dv_ptr,
1559  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1560  make_tuple(kargs.stride_dv, 1),
1562  number<1>{});
1563 
1564  return pad_tensor_view(
1565  dv_dram_naive,
1568  }();
1569 
1570  auto dk_dram_window = make_tile_window(
1571  dk_dram,
1573  {i_n0, 0});
1574 
1575  auto dv_dram_window = make_tile_window(
1576  dv_dram,
1578  {i_n0, 0});
1579 
1580  KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
1581  VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
1582  }
1583 };
1584 
1585 template <typename FmhaBwdOGradDotO_>
1587 {
1589  static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize;
1590  static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
1591  static constexpr ck_tile::index_t kM0 = kBlockSize;
1592  static constexpr ck_tile::index_t kVHeaddim = FmhaBwdOGradDotO::kVHeaddim;
1593 
1597 
1598  static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode;
1599  static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ;
1600  static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV;
1601 
    // t2s: compile-time map from element type to the short string tag that
    // GetName() embeds in the generated kernel name (sync with generate.py).
    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    // clang-format on
1607 
1608  CK_TILE_HOST static std::string GetName()
1609  {
1610  // sync with generate.py
1611  // clang-format off
1612 
1613  #define _SS_ std::string
1614  #define _TS_ std::to_string
1615  auto pn = [&] () {
1616  std::string n;
1617  if (kPadSeqLenQ) n += "s";
1618  if (kPadHeadDimV) n += "dv";
1619  return n.empty() ? n : std::string("p") + n; }();
1620  return
1621  _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s<ODataType>::name) +
1622  "_" + (kIsGroupMode ? "group" : "batch") + "_" +
1623  ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn);
1624  #undef _SS_
1625  #undef _TS_
1626  // clang-format on
1627  }
1628 
1629  // kargs use aggregate initializer, so no constructor will provided
1630  // use inheritance to minimize karg size
1631  // user need to use MakeKargs() function to create kargs.
1633  {
1634  const void* o_ptr;
1635  const void* do_ptr;
1636  void* d_ptr;
1637 
1638  float p_undrop;
1639 
1642 
1645 
1649  };
1650 
1652  {
1656  };
1657 
1659  {
1660  const int32_t* seqstart_q_ptr;
1661  };
1662 
1663  using Kargs = std::
1664  conditional_t<kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs>;
1665 
1666  template <bool Cond = !kIsGroupMode>
1667  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1668  MakeKargs(const void* o_ptr,
1669  const void* do_ptr,
1670  void* d_ptr,
1671  float p_undrop,
1672  ck_tile::index_t seqlen_q,
1673  ck_tile::index_t hdim_v,
1674  ck_tile::index_t stride_do,
1675  ck_tile::index_t stride_o,
1676  ck_tile::index_t nhead_stride_do,
1677  ck_tile::index_t nhead_stride_o,
1678  ck_tile::index_t nhead_stride_d,
1679  ck_tile::index_t batch_stride_do,
1680  ck_tile::index_t batch_stride_o,
1681  ck_tile::index_t batch_stride_d)
1682  {
1683  Kargs kargs{{o_ptr,
1684  do_ptr,
1685  d_ptr,
1686  p_undrop,
1687  seqlen_q,
1688  hdim_v,
1689  stride_do,
1690  stride_o,
1691  nhead_stride_do,
1692  nhead_stride_o,
1693  nhead_stride_d},
1694  batch_stride_do,
1695  batch_stride_o,
1696  batch_stride_d};
1697 
1698  return kargs;
1699  }
1700 
1701  template <bool Cond = kIsGroupMode>
1702  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1703  MakeKargs(const void* o_ptr,
1704  const void* do_ptr,
1705  void* d_ptr,
1706  float p_undrop,
1707  const void* seqstart_q_ptr,
1708  ck_tile::index_t hdim_v,
1709  ck_tile::index_t stride_do,
1710  ck_tile::index_t stride_o,
1711  ck_tile::index_t nhead_stride_do,
1712  ck_tile::index_t nhead_stride_o,
1713  ck_tile::index_t nhead_stride_d)
1714  {
1715  Kargs kargs{{o_ptr,
1716  do_ptr,
1717  d_ptr,
1718  p_undrop,
1719  -1, // seqlen will be updated by another pointer
1720  hdim_v,
1721  stride_do,
1722  stride_o,
1723  nhead_stride_do,
1724  nhead_stride_o,
1725  nhead_stride_d},
1726  reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
1727 
1728  return kargs;
1729  }
1730 
1731  CK_TILE_HOST static constexpr auto
1733  {
1734  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
1735  }
1736 
1737  CK_TILE_DEVICE static constexpr auto GetTileIndex()
1738  {
1739  const index_t i_block = blockIdx.x;
1740  const index_t i_nhead = blockIdx.y;
1741  const index_t i_batch = blockIdx.z;
1742 
1743  return ck_tile::make_tuple(i_block, i_nhead, i_batch);
1744  }
1745 
    // Workgroup shape for kernel launch: kBlockSize threads in x, 1 in y/z.
    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
1747 
    // The dot(dO, O) reduction works entirely in registers; no LDS is needed.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
1749 
1750  CK_TILE_DEVICE void operator()(Kargs kargs) const
1751  {
1752  // divide problem
1753  const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
1754 
1755  const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
1756 
1757  long_index_t batch_offset_o = 0;
1758  long_index_t batch_offset_do = 0;
1759  long_index_t batch_offset_d = 0;
1760 
1761  if constexpr(kIsGroupMode)
1762  {
1763  // get starting offset for each batch
1764  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1765 
1766  batch_offset_o = query_start * kargs.stride_o;
1767  batch_offset_do = query_start * kargs.stride_do;
1768  batch_offset_d = query_start;
1769 
1770  // get real # queries & # keys under group mode
1771  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1772  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1773  // # of required blocks is different in each groups, terminate unnecessary blocks
1774  // earlier
1775  if(kargs.seqlen_q <= i_m0)
1776  {
1777  return;
1778  }
1779  }
1780  else
1781  {
1782  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1783  batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
1784  batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
1785  }
1786 
1787  // for simplicity, batch stride we just modify the pointer
1788  const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr) +
1789  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1790  batch_offset_o;
1791  const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
1792  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
1793  batch_offset_do;
1794  DDataType* d_ptr = reinterpret_cast<DDataType*>(kargs.d_ptr) +
1795  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_d +
1796  batch_offset_d;
1797 
1798  // O/dO/D DRAM and DRAM window
1799  const auto o_dram = [&]() {
1800  auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1801  o_ptr,
1802  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1803  make_tuple(kargs.stride_o, 1),
1805  number<1>{});
1806  return pad_tensor_view(o_dram_naive,
1809  }();
1810  const auto do_dram = [&]() {
1811  auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1812  do_ptr,
1813  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1814  make_tuple(kargs.stride_do, 1),
1816  number<1>{});
1817  return pad_tensor_view(do_dram_naive,
1820  }();
1821  auto d_dram = [&]() {
1822  const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
1823  d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
1824  return pad_tensor_view(
1825  d_dram_naive, make_tuple(number<kM0>{}), sequence<kPadSeqLenQ>{});
1826  }();
1827 
1828  auto o_dram_window =
1829  make_tile_window(o_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
1830 
1831  auto do_dram_window =
1832  make_tile_window(do_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
1833 
1834  auto d_dram_window = make_tile_window(d_dram, make_tuple(number<kM0>{}), {i_m0});
1835 
1836  FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop);
1837  }
1838 };
1839 
1840 template <typename FmhaBwdConvertQGrad_>
1842 {
1844  static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize;
1845  static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
1846  static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0;
1847  static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0;
1848  static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim;
1849 
1852 
1853  static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode;
1854  static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
1855  static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
1856  static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
1857 
    // t2s: compile-time map from element type to the short string tag that
    // GetName() embeds in the generated kernel name (sync with generate.py).
    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    // clang-format on
1863 
1864  CK_TILE_HOST static std::string GetName()
1865  {
1866  // sync with generate.py
1867  // clang-format off
1868 
1869  #define _SS_ std::string
1870  #define _TS_ std::to_string
1871  auto pn = [&] () {
1872  std::string n;
1873  if (kPadSeqLenQ) n += "s";
1874  if (kPadHeadDimQ) n += "d";
1875  return n.empty() ? n : std::string("p") + n; }();
1876  return
1877  _SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_" + _SS_(t2s<QGradDataType>::name) +
1878  "_" + (kIsGroupMode ? "group" : "batch") + (kIsDeterministic ? "_deterministic" : "") + "_" +
1879  ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn);
1880  #undef _SS_
1881  #undef _TS_
1882  // clang-format on
1883  }
1884 
1885  // to avoid duplicated base class prblem, introduce an template arg
1886  template <ck_tile::index_t I>
1888  {
1889  };
1890 
1891  // kargs use aggregate initializer, so no constructor will provided
1892  // use inheritance to minimize karg size
1893  // user need to use MakeKargs() function to create kargs.
1895  {
1896  const void* dq_acc_ptr;
1897  void* dq_ptr;
1898 
1902 
1907  };
1908 
1910  {
1912  };
1913 
1916  std::conditional_t<kIsDeterministic,
1917  FmhaBwdConvertQGradDeterministicKargs,
1918  FmhaBwdConvertQGradEmptyKargs<0>>
1919  {
1922  };
1923 
1926  std::conditional_t<kIsDeterministic,
1927  FmhaBwdConvertQGradDeterministicKargs,
1928  FmhaBwdConvertQGradEmptyKargs<0>>
1929  {
1930  const int32_t* seqstart_q_ptr;
1931  const int32_t* seqstart_k_ptr;
1932  };
1933 
1937 
1938  template <bool Cond = !kIsGroupMode>
1939  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1940  MakeKargs(const void* dq_acc_ptr,
1941  void* dq_ptr,
1942  ck_tile::index_t seqlen_q,
1943  ck_tile::index_t seqlen_k,
1944  ck_tile::index_t hdim_q,
1945  ck_tile::index_t stride_dq,
1946  ck_tile::index_t stride_dq_acc,
1947  ck_tile::index_t nhead_stride_dq,
1948  ck_tile::index_t nhead_stride_dq_acc,
1949  ck_tile::index_t batch_stride_dq,
1950  ck_tile::index_t batch_stride_dq_acc,
1951  ck_tile::index_t split_stride_dq_acc)
1952  {
1953  Kargs kargs{{dq_acc_ptr,
1954  dq_ptr,
1955  seqlen_q,
1956  seqlen_k,
1957  hdim_q,
1958  stride_dq,
1959  stride_dq_acc,
1960  nhead_stride_dq,
1961  nhead_stride_dq_acc},
1962  {},
1963  batch_stride_dq,
1964  batch_stride_dq_acc};
1965 
1966  if constexpr(kIsDeterministic)
1967  {
1968  kargs.split_stride_dq_acc = split_stride_dq_acc;
1969  }
1970 
1971  return kargs;
1972  }
1973 
1974  template <bool Cond = kIsGroupMode>
1975  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1976  MakeKargs(const void* dq_acc_ptr,
1977  void* dq_ptr,
1978  const void* seqstart_q_ptr,
1979  const void* seqstart_k_ptr,
1980  ck_tile::index_t hdim_q,
1981  ck_tile::index_t stride_dq,
1982  ck_tile::index_t stride_dq_acc,
1983  ck_tile::index_t nhead_stride_dq,
1984  ck_tile::index_t nhead_stride_dq_acc,
1985  ck_tile::index_t split_stride_dq_acc)
1986  {
1987  Kargs kargs{{dq_acc_ptr,
1988  dq_ptr,
1989  -1, // seqlen will be updated by another pointer
1990  -1, //
1991  hdim_q,
1992  stride_dq,
1993  stride_dq_acc,
1994  nhead_stride_dq,
1995  nhead_stride_dq_acc},
1996  {},
1997  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
1998  reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
1999 
2000  if constexpr(kIsDeterministic)
2001  {
2002  kargs.split_stride_dq_acc = split_stride_dq_acc;
2003  }
2004 
2005  return kargs;
2006  }
2007 
2008  CK_TILE_HOST static constexpr auto
2010  {
2011  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
2012  }
2013 
2014  CK_TILE_DEVICE static constexpr auto GetTileIndex()
2015  {
2016  const index_t i_block = blockIdx.x;
2017  const index_t i_nhead = blockIdx.y;
2018  const index_t i_batch = blockIdx.z;
2019 
2020  return ck_tile::make_tuple(i_block, i_nhead, i_batch);
2021  }
2022 
    // Workgroup shape for kernel launch: kBlockSize threads in x, 1 in y/z.
    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
2024 
    // The dQ conversion is a pure load-convert-store pass; no LDS is needed.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
2026 
2027  CK_TILE_DEVICE void operator()(Kargs kargs) const
2028  {
2029  // divide problem
2030  const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
2031 
2032  const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
2033 
2034  long_index_t batch_offset_dq = 0;
2035  long_index_t batch_offset_dq_acc = 0;
2036  if constexpr(kIsGroupMode)
2037  {
2038  // get starting offset for each batch
2039  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
2040  batch_offset_dq = query_start * kargs.stride_dq;
2041  batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
2042 
2043  // get real # queries & # keys under group mode
2044  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
2045  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
2046  if constexpr(kIsDeterministic)
2047  {
2048  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
2049  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
2050  }
2051  // # of required blocks is different in each groups, terminate unnecessary blocks
2052  // earlier
2053  if(kargs.seqlen_q <= i_m0)
2054  {
2055  return;
2056  }
2057  }
2058  else
2059  {
2060  batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
2061  batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
2062  }
2063 
2064  // for simplicity, batch stride we just modify the pointer
2065  QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
2066  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
2067  batch_offset_dq;
2068 
2069  // dQAcc/dQ DRAM and DRAM window
2070  const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
2071  if constexpr(kIsDeterministic)
2072  {
2073  const AccDataType* dq_acc_ptr =
2074  reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
2075  static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
2076  batch_offset_dq_acc;
2077 
2078  const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
2079 
2080  auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
2081  dq_acc_ptr,
2082  make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
2083  make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
2085  number<1>{});
2086  return pad_tensor_view(dq_acc_dram_naive,
2089  }
2090  else
2091  {
2092  const AccDataType* dq_acc_ptr =
2093  reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
2094  static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
2095  batch_offset_dq_acc;
2096 
2097  auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
2098  dq_acc_ptr,
2099  make_tuple(kargs.seqlen_q, kargs.hdim_q),
2100  make_tuple(kargs.stride_dq_acc, 1),
2102  number<1>{});
2103  return pad_tensor_view(dq_acc_dram_naive,
2106  }
2107  }();
2108 
2109  auto dq_dram = [&]() {
2110  auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global>(
2111  dq_ptr,
2112  make_tuple(kargs.seqlen_q, kargs.hdim_q),
2113  make_tuple(kargs.stride_dq, 1),
2115  number<1>{});
2116  return pad_tensor_view(dq_dram_naive,
2119  }();
2120 
2121  auto dq_acc_dram_window = [&]() {
2122  if constexpr(kIsDeterministic)
2123  {
2124  return make_tile_window(
2125  dq_acc_dram,
2127  {0, i_m0, 0});
2128  }
2129  else
2130  {
2131  return make_tile_window(
2132  dq_acc_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
2133  }
2134  }();
2135 
2136  auto dq_dram_window =
2137  make_tile_window(dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
2138 
2139  if constexpr(kIsDeterministic)
2140  {
2141  const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
2142  FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
2143  }
2144  else
2145  {
2146  FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
2147  }
2148  }
2149 };
2150 
2151 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
#define _TS_
#define _SS_
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:63
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_view(DataType *p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_view.hpp:424
_Float16 fp16_t
Definition: half.hpp:110
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
int64_t long_index_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_position_encoding.hpp:137
ck_tile::index_t batch_stride_dq
Definition: fmha_bwd_kernel.hpp:1920
ck_tile::index_t batch_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1921
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:1899
ck_tile::index_t stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1904
ck_tile::index_t nhead_stride_dq
Definition: fmha_bwd_kernel.hpp:1905
ck_tile::index_t hdim_q
Definition: fmha_bwd_kernel.hpp:1901
ck_tile::index_t stride_dq
Definition: fmha_bwd_kernel.hpp:1903
const void * dq_acc_ptr
Definition: fmha_bwd_kernel.hpp:1896
ck_tile::index_t seqlen_k
Definition: fmha_bwd_kernel.hpp:1900
ck_tile::index_t nhead_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1906
ck_tile::index_t split_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:1911
const int32_t * seqstart_k_ptr
Definition: fmha_bwd_kernel.hpp:1931
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:1930
Definition: fmha_bwd_kernel.hpp:1859
Definition: fmha_bwd_kernel.hpp:1842
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:1853
static constexpr bool kPadHeadDimQ
Definition: fmha_bwd_kernel.hpp:1855
static constexpr bool kIsDeterministic
Definition: fmha_bwd_kernel.hpp:1856
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:1844
static constexpr bool kPadSeqLenQ
Definition: fmha_bwd_kernel.hpp:1854
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition: fmha_bwd_kernel.hpp:1976
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:1845
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:2014
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t batch_stride_dq, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition: fmha_bwd_kernel.hpp:1940
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:2025
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:2027
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::AccDataType > AccDataType
Definition: fmha_bwd_kernel.hpp:1850
static constexpr ck_tile::index_t kM0
Definition: fmha_bwd_kernel.hpp:1846
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::QGradDataType > QGradDataType
Definition: fmha_bwd_kernel.hpp:1851
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition: fmha_bwd_kernel.hpp:2009
static constexpr ck_tile::index_t kN0
Definition: fmha_bwd_kernel.hpp:1847
ck_tile::remove_cvref_t< FmhaBwdConvertQGrad_ > FmhaBwdConvertQGrad
Definition: fmha_bwd_kernel.hpp:1843
std::conditional_t< kIsGroupMode, FmhaBwdConvertQGradGroupModeKargs, FmhaBwdConvertQGradBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:1936
static constexpr ck_tile::index_t kQKHeaddim
Definition: fmha_bwd_kernel.hpp:1848
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:2023
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:1864
Definition: fmha_bwd_kernel.hpp:176
ck_tile::index_t alibi_slope_stride
Definition: fmha_bwd_kernel.hpp:179
const void * alibi_slope_ptr
Definition: fmha_bwd_kernel.hpp:178
ck_tile::index_t batch_stride_dbias
Definition: fmha_bwd_kernel.hpp:191
ck_tile::index_t batch_stride_bias
Definition: fmha_bwd_kernel.hpp:172
ck_tile::index_t batch_stride_randval
Definition: fmha_bwd_kernel.hpp:256
Definition: fmha_bwd_kernel.hpp:275
ck_tile::index_t batch_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:281
ck_tile::index_t batch_stride_v
Definition: fmha_bwd_kernel.hpp:278
ck_tile::index_t batch_stride_k
Definition: fmha_bwd_kernel.hpp:277
ck_tile::index_t batch_stride_lsed
Definition: fmha_bwd_kernel.hpp:280
ck_tile::index_t batch_stride_dv
Definition: fmha_bwd_kernel.hpp:283
ck_tile::index_t batch_stride_dk
Definition: fmha_bwd_kernel.hpp:282
ck_tile::index_t batch_stride_do
Definition: fmha_bwd_kernel.hpp:279
ck_tile::index_t batch_stride_q
Definition: fmha_bwd_kernel.hpp:276
ck_tile::index_t stride_dbias
Definition: fmha_bwd_kernel.hpp:185
void * dbias_ptr
Definition: fmha_bwd_kernel.hpp:184
ck_tile::index_t nhead_stride_dbias
Definition: fmha_bwd_kernel.hpp:186
Definition: fmha_bwd_kernel.hpp:164
ck_tile::index_t stride_bias
Definition: fmha_bwd_kernel.hpp:166
const void * bias_ptr
Definition: fmha_bwd_kernel.hpp:165
ck_tile::index_t nhead_stride_bias
Definition: fmha_bwd_kernel.hpp:167
ck_tile::index_t stride_randval
Definition: fmha_bwd_kernel.hpp:250
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
Definition: fmha_bwd_kernel.hpp:216
float scale_rp_undrop
Definition: fmha_bwd_kernel.hpp:246
float rp_undrop
Definition: fmha_bwd_kernel.hpp:245
ck_tile::index_t nhead_stride_randval
Definition: fmha_bwd_kernel.hpp:251
uint8_t p_undrop_in_uint8_t
Definition: fmha_bwd_kernel.hpp:247
void * rand_val_ptr
Definition: fmha_bwd_kernel.hpp:248
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr, float raw_scale)
Definition: fmha_bwd_kernel.hpp:229
Definition: fmha_bwd_kernel.hpp:122
float raw_scale
Definition: fmha_bwd_kernel.hpp:142
void * dq_acc_ptr
Definition: fmha_bwd_kernel.hpp:129
ck_tile::index_t nhead_stride_do
Definition: fmha_bwd_kernel.hpp:156
ck_tile::index_t stride_dk
Definition: fmha_bwd_kernel.hpp:150
ck_tile::index_t hdim_v
Definition: fmha_bwd_kernel.hpp:136
void * dv_ptr
Definition: fmha_bwd_kernel.hpp:131
ck_tile::index_t nhead_stride_dv
Definition: fmha_bwd_kernel.hpp:160
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:133
ck_tile::index_t nhead_stride_lsed
Definition: fmha_bwd_kernel.hpp:157
ck_tile::index_t nhead_stride_q
Definition: fmha_bwd_kernel.hpp:153
ck_tile::index_t nhead_stride_dk
Definition: fmha_bwd_kernel.hpp:159
ck_tile::index_t stride_v
Definition: fmha_bwd_kernel.hpp:147
ck_tile::index_t stride_dq_acc
Definition: fmha_bwd_kernel.hpp:149
ck_tile::index_t num_head_q
Definition: fmha_bwd_kernel.hpp:140
ck_tile::index_t nhead_stride_k
Definition: fmha_bwd_kernel.hpp:154
ck_tile::index_t nhead_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:158
const void * lse_ptr
Definition: fmha_bwd_kernel.hpp:126
const void * d_ptr
Definition: fmha_bwd_kernel.hpp:128
ck_tile::index_t nhead_ratio_qk
Definition: fmha_bwd_kernel.hpp:141
ck_tile::index_t stride_k
Definition: fmha_bwd_kernel.hpp:146
const void * do_ptr
Definition: fmha_bwd_kernel.hpp:127
float scale
Definition: fmha_bwd_kernel.hpp:143
void * dk_ptr
Definition: fmha_bwd_kernel.hpp:130
const void * v_ptr
Definition: fmha_bwd_kernel.hpp:125
ck_tile::index_t stride_do
Definition: fmha_bwd_kernel.hpp:148
const void * k_ptr
Definition: fmha_bwd_kernel.hpp:124
const void * q_ptr
Definition: fmha_bwd_kernel.hpp:123
ck_tile::index_t nhead_stride_v
Definition: fmha_bwd_kernel.hpp:155
ck_tile::index_t stride_dv
Definition: fmha_bwd_kernel.hpp:151
ck_tile::index_t stride_q
Definition: fmha_bwd_kernel.hpp:145
ck_tile::index_t seqlen_k
Definition: fmha_bwd_kernel.hpp:134
ck_tile::index_t hdim_q
Definition: fmha_bwd_kernel.hpp:135
ck_tile::index_t split_stride_dq_acc
Definition: fmha_bwd_kernel.hpp:261
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_bwd_kernel.hpp:209
bool is_drop_seed_offset_from_host
Definition: fmha_bwd_kernel.hpp:211
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_bwd_kernel.hpp:210
Definition: fmha_bwd_kernel.hpp:115
Definition: fmha_bwd_kernel.hpp:297
const int32_t * seqstart_k_ptr
Definition: fmha_bwd_kernel.hpp:299
const int32_t * seqlen_k_ptr
Definition: fmha_bwd_kernel.hpp:300
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:298
Definition: fmha_bwd_kernel.hpp:195
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_bwd_kernel.hpp:197
ck_tile::index_t window_size_right
Definition: fmha_bwd_kernel.hpp:196
ck_tile::index_t window_size_left
Definition: fmha_bwd_kernel.hpp:196
Definition: fmha_bwd_kernel.hpp:69
Definition: fmha_bwd_kernel.hpp:31
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:307
static constexpr bool kHasDropout
Definition: fmha_bwd_kernel.hpp:64
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:965
ck_tile::remove_cvref_t< typename FmhaPipeline::DDataType > DDataType
Definition: fmha_bwd_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_bwd_kernel.hpp:39
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasGradDataType > BiasGradDataType
Definition: fmha_bwd_kernel.hpp:52
static constexpr bool kHasBiasGrad
Definition: fmha_bwd_kernel.hpp:60
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_bwd_kernel.hpp:61
ck_tile::remove_cvref_t< typename FmhaPipeline::QGradDataType > QGradDataType
Definition: fmha_bwd_kernel.hpp:49
ck_tile::remove_cvref_t< typename FmhaPipeline::OGradDataType > OGradDataType
Definition: fmha_bwd_kernel.hpp:48
static constexpr bool kIsDeterministic
Definition: fmha_bwd_kernel.hpp:66
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:597
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:36
static constexpr bool kPadSeqLenK
Definition: fmha_bwd_kernel.hpp:56
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_bwd_kernel.hpp:32
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:864
static constexpr bool kPadHeadDimV
Definition: fmha_bwd_kernel.hpp:58
ck_tile::remove_cvref_t< typename FmhaPipeline::KGradDataType > KGradDataType
Definition: fmha_bwd_kernel.hpp:50
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:476
std::conditional_t< kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:303
static constexpr bool kPadHeadDimQ
Definition: fmha_bwd_kernel.hpp:57
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:1070
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaDropout > FmhaDropout
Definition: fmha_bwd_kernel.hpp:62
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_bwd_kernel.hpp:43
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_bwd_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::AccDataType > AccDataType
Definition: fmha_bwd_kernel.hpp:44
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:74
ck_tile::remove_cvref_t< KGradEpiloguePipeline_ > KGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:33
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:1088
ck_tile::remove_cvref_t< VGradEpiloguePipeline_ > VGradEpiloguePipeline
Definition: fmha_bwd_kernel.hpp:34
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_bwd_kernel.hpp:717
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:1079
static constexpr bool kHasMask
Definition: fmha_bwd_kernel.hpp:63
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
Definition: fmha_bwd_kernel.hpp:1064
static constexpr auto BiasEnum
Definition: fmha_bwd_kernel.hpp:59
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_bwd_kernel.hpp:47
ck_tile::remove_cvref_t< typename FmhaPipeline::GemmDataType > GemmDataType
Definition: fmha_bwd_kernel.hpp:42
static constexpr bool kPadSeqLenQ
Definition: fmha_bwd_kernel.hpp:55
ck_tile::remove_cvref_t< typename FmhaPipeline::VGradDataType > VGradDataType
Definition: fmha_bwd_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_bwd_kernel.hpp:41
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:35
static constexpr bool kIsStoreRandval
Definition: fmha_bwd_kernel.hpp:65
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_bwd_kernel.hpp:40
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:54
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:1081
ck_tile::index_t batch_stride_o
Definition: fmha_bwd_kernel.hpp:1654
ck_tile::index_t batch_stride_do
Definition: fmha_bwd_kernel.hpp:1653
ck_tile::index_t batch_stride_d
Definition: fmha_bwd_kernel.hpp:1655
void * d_ptr
Definition: fmha_bwd_kernel.hpp:1636
const void * o_ptr
Definition: fmha_bwd_kernel.hpp:1634
ck_tile::index_t hdim_v
Definition: fmha_bwd_kernel.hpp:1641
ck_tile::index_t nhead_stride_do
Definition: fmha_bwd_kernel.hpp:1646
ck_tile::index_t stride_o
Definition: fmha_bwd_kernel.hpp:1644
ck_tile::index_t nhead_stride_o
Definition: fmha_bwd_kernel.hpp:1647
const void * do_ptr
Definition: fmha_bwd_kernel.hpp:1635
ck_tile::index_t stride_do
Definition: fmha_bwd_kernel.hpp:1643
ck_tile::index_t seqlen_q
Definition: fmha_bwd_kernel.hpp:1640
float p_undrop
Definition: fmha_bwd_kernel.hpp:1638
ck_tile::index_t nhead_stride_d
Definition: fmha_bwd_kernel.hpp:1648
const int32_t * seqstart_q_ptr
Definition: fmha_bwd_kernel.hpp:1660
Definition: fmha_bwd_kernel.hpp:1603
Definition: fmha_bwd_kernel.hpp:1587
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::ODataType > ODataType
Definition: fmha_bwd_kernel.hpp:1595
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_bwd_kernel.hpp:1748
ck_tile::remove_cvref_t< FmhaBwdOGradDotO_ > FmhaBwdOGradDotO
Definition: fmha_bwd_kernel.hpp:1588
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_d)
Definition: fmha_bwd_kernel.hpp:1668
static constexpr ck_tile::index_t kVHeaddim
Definition: fmha_bwd_kernel.hpp:1592
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_bwd_kernel.hpp:1750
static constexpr bool kIsGroupMode
Definition: fmha_bwd_kernel.hpp:1598
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::OGradDataType > OGradDataType
Definition: fmha_bwd_kernel.hpp:1596
static constexpr ck_tile::index_t kM0
Definition: fmha_bwd_kernel.hpp:1591
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, const void *seqstart_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d)
Definition: fmha_bwd_kernel.hpp:1703
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_bwd_kernel.hpp:1589
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_bwd_kernel.hpp:1590
static constexpr CK_TILE_DEVICE auto GetTileIndex()
Definition: fmha_bwd_kernel.hpp:1737
static constexpr bool kPadSeqLenQ
Definition: fmha_bwd_kernel.hpp:1599
std::conditional_t< kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs > Kargs
Definition: fmha_bwd_kernel.hpp:1664
static CK_TILE_HOST std::string GetName()
Definition: fmha_bwd_kernel.hpp:1608
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::DDataType > DDataType
Definition: fmha_bwd_kernel.hpp:1594
static constexpr bool kPadHeadDimV
Definition: fmha_bwd_kernel.hpp:1600
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_bwd_kernel.hpp:1746
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition: fmha_bwd_kernel.hpp:1732
Definition: integral_constant.hpp:13
Definition: coordinate_transform.hpp:1443
Definition: sequence.hpp:52