/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp Source File
fmha_fwd_v3_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <type_traits>
12 #include <utility>
13 
14 namespace ck_tile {
15 
// Forward multi-head-attention kernel (v3) wrapper. This rendered listing elides the
// struct-name line (19), the type aliases (21-22, 27-32, 42-43) and other hyperlinked lines.
// FmhaPipeline_ computes the attention output tile (see operator()); EpiloguePipeline_
// writes that accumulator tile to the O tensor.
18 template <typename FmhaPipeline_, typename EpiloguePipeline_>
20 {
// Workgroup size (used by BlockSize()) and blocks-per-CU hint, forwarded from the pipeline.
23  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
24  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
25  static_assert(kBlockPerCu > 0);
26 
33 
// Compile-time feature switches mirrored from the pipeline. They select the Kargs layout
// (group vs batch mode, optional mask / LSE / soft-cap bases) and code paths in operator().
34  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
35  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
36  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
37  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
38  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
39  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
40  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
41 
// Masking is enabled exactly when the pipeline's mask type reports IsMasking.
44  static constexpr bool kHasMask = FmhaMask::IsMasking;
45 
// Empty placeholder base used when an optional feature is disabled.
46  template <ck_tile::index_t I> // the index I gives every placeholder a distinct type, so
47  // several can be inherited together without duplicate-base-class errors
49  {
50  };
51 
52  // Kargs are built with aggregate initialization, so no constructors are provided.
53  // Optional features are mixed in through inheritance to keep the kargs size minimal.
54  // Users must create kargs through MakeKargs().
// Arguments shared by batch and group mode (struct header line 55 is elided here). Per the
// index at the end of this page, seqlen_q/seqlen_k (62-63), hdim_q/hdim_v (64-65),
// num_head_q (67), nhead_ratio_qk (70), stride_q/k/v/o (73-76) and
// nhead_stride_q/k/v/o (78-81) are also members; their declaration lines are elided.
56  {
// Untyped global-memory base pointers; operator() casts them to Q/K/V/ODataType.
57  const void* q_ptr;
58  const void* k_ptr;
59  const void* v_ptr;
60  void* o_ptr;
61 
66 
68  // For MQA/GQA the number of K/V heads differs from Q. nhead_ratio_qk is nhead_q / nhead_k;
69  // a value larger than 1 indicates the MQA/GQA case.
// Softmax scale, stored already multiplied by log2(e) (see MakeKargs).
71  float scale_s;
72 
77 
82  };
83 
// Mask arguments. Per the page index, window_size_left/right (87), mask_type (88) and
// remap_opt (89) are members; their declaration lines are elided in this listing.
85  {
86  // ck_tile::index_t window_size_left, window_size_right;
90  };
91 
// LSE output arguments: lse_ptr plus nhead_stride_lse (95) and batch_stride_lse (96),
// whose declaration lines are elided in this listing.
93  {
94  void* lse_ptr = nullptr;
97  };
98 
// Logits soft-cap arguments. Per the page index, the members logits_soft_cap (117) and
// logits_soft_cap_rcp (118) are declared on lines elided from this listing.
100  {
102 
// Store the soft-cap value. A non-positive input disables the cap: both the cap and its
// reciprocal are set to 0.
103  void init_logits_soft_cap(float logits_soft_cap_)
104  {
105  if(0 < logits_soft_cap_)
106  {
107  logits_soft_cap = logits_soft_cap_;
// NOTE(review): source line 108 is elided from this listing; it presumably assigns
// logits_soft_cap_rcp for the enabled case — confirm against the real header.
109  }
110  else
111  {
112  logits_soft_cap = 0.f;
113  logits_soft_cap_rcp = 0.f;
114  }
115  }
116 
119  };
120 
123  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
124  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
125  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
126  {
// Batch-mode kargs. The struct header (121-122) is elided. Each optional base above is the
// real feature kargs when enabled, else a distinct empty placeholder. Per the page index,
// batch_stride_q/k/v/o (127-130) are also members on elided lines.
131 
132  // Optional cumulative sequence length pointers for batch mode
133  // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
134  const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
135  const ck_tile::index_t* cu_seqlen_k_ptr = nullptr; // [batch+1]
136  };
137 
140  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
141  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
142  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
143  {
// Group-mode (variable-length) kargs. The struct header (138-139) is elided. Per the page
// index, seqstart_q_ptr/seqstart_k_ptr (144-145) and seqlen_q_ptr/seqlen_k_ptr (146-147)
// are also members on elided lines; operator() takes memory offsets from seqstart_*_ptr.
148 
149  // Optional cumulative sequence lengths. When seqlen_*_ptr is null, operator() uses
150  // cu[i + 1] - cu[i] as batch i's logical length; memory offsets come from seqstart_*_ptr.
151  const int32_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
152  const int32_t* cu_seqlen_k_ptr = nullptr; // [batch+1]
153  };
154 
// Kernel argument type: the group-mode layout when kIsGroupMode, otherwise batch mode.
155  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
156 
// Block-identity record passed to the pipeline (header line 157 elided). Per the page index
// its members are batch_idx, qo_head_idx, kv_head_idx (159-161); operator() fills it with
// {i_batch, i_nhead, i_nhead / nhead_ratio_qk}.
158  {
162  };
163 
164  template <bool Cond = !kIsGroupMode>
165  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
166  MakeKargs(const void* q_ptr,
167  const void* k_ptr,
168  const void* v_ptr,
169  void* lse_ptr,
170  void* o_ptr,
171  ck_tile::index_t seqlen_q,
172  ck_tile::index_t seqlen_k,
173  ck_tile::index_t hdim_q,
174  ck_tile::index_t hdim_v,
175  ck_tile::index_t num_head_q,
176  ck_tile::index_t nhead_ratio_qk,
177  float scale_s,
178  float logits_soft_cap,
179  ck_tile::index_t stride_q,
180  ck_tile::index_t stride_k,
181  ck_tile::index_t stride_v,
182  ck_tile::index_t stride_o,
183  ck_tile::index_t nhead_stride_q,
184  ck_tile::index_t nhead_stride_k,
185  ck_tile::index_t nhead_stride_v,
186  ck_tile::index_t nhead_stride_lse,
187  ck_tile::index_t nhead_stride_o,
188  ck_tile::index_t batch_stride_q,
189  ck_tile::index_t batch_stride_k,
190  ck_tile::index_t batch_stride_v,
191  ck_tile::index_t batch_stride_lse,
192  ck_tile::index_t batch_stride_o,
193  ck_tile::index_t window_size_left,
194  ck_tile::index_t window_size_right,
195  ck_tile::index_t mask_type,
196  ck_tile::index_t remap_opt,
197  const void* cu_seqlen_q_ptr = nullptr,
198  const void* cu_seqlen_k_ptr = nullptr)
199  {
200  Kargs kargs{{q_ptr,
201  k_ptr,
202  v_ptr,
203  o_ptr,
204  seqlen_q,
205  seqlen_k,
206  hdim_q,
207  hdim_v,
208  num_head_q,
209  nhead_ratio_qk,
210  static_cast<float>(scale_s * ck_tile::log2e_v<>),
211  stride_q,
212  stride_k,
213  stride_v,
214  stride_o,
215  nhead_stride_q,
216  nhead_stride_k,
217  nhead_stride_v,
218  nhead_stride_o}, // args for common karg
219  {}, // placeholder for mask
220  {}, // placeholder for lse
221  {}, // placeholder for logits_soft_cap
222  batch_stride_q,
223  batch_stride_k,
224  batch_stride_v,
225  batch_stride_o};
226 
227  if constexpr(kHasMask)
228  {
229  kargs.window_size_left = window_size_left;
230  kargs.window_size_right = window_size_right;
231  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
232  kargs.remap_opt = remap_opt;
233  }
234  if constexpr(kStoreLSE)
235  {
236  kargs.lse_ptr = lse_ptr;
237  kargs.nhead_stride_lse = nhead_stride_lse;
238  kargs.batch_stride_lse = batch_stride_lse;
239  }
240  if constexpr(kHasLogitsSoftCap)
241  {
242  kargs.init_logits_soft_cap(logits_soft_cap);
243  }
244 
245  kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
246  kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
247  return kargs;
248  }
249 
250  template <bool Cond = kIsGroupMode>
251  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
252  MakeKargs(const void* q_ptr,
253  const void* k_ptr,
254  const void* v_ptr,
255  void* lse_ptr,
256  void* o_ptr,
257  const void* seqstart_q_ptr,
258  const void* seqstart_k_ptr,
259  const void* seqlen_q_ptr,
260  const void* seqlen_k_ptr,
261  ck_tile::index_t hdim_q,
262  ck_tile::index_t hdim_v,
263  ck_tile::index_t num_head_q,
264  ck_tile::index_t nhead_ratio_qk,
265  float scale_s,
266  float logits_soft_cap,
267  ck_tile::index_t stride_q,
268  ck_tile::index_t stride_k,
269  ck_tile::index_t stride_v,
270  ck_tile::index_t stride_o,
271  ck_tile::index_t nhead_stride_q,
272  ck_tile::index_t nhead_stride_k,
273  ck_tile::index_t nhead_stride_v,
274  ck_tile::index_t nhead_stride_lse,
275  ck_tile::index_t nhead_stride_o,
276  ck_tile::index_t window_size_left,
277  ck_tile::index_t window_size_right,
278  ck_tile::index_t mask_type,
279  ck_tile::index_t remap_opt,
280  const void* cu_seqlen_q_ptr = nullptr,
281  const void* cu_seqlen_k_ptr = nullptr)
282  {
283  Kargs kargs{{q_ptr,
284  k_ptr,
285  v_ptr,
286  o_ptr,
287  -1, // seqlen will be updated by another pointer
288  -1, //
289  hdim_q,
290  hdim_v,
291  num_head_q,
292  nhead_ratio_qk,
293  static_cast<float>(scale_s * ck_tile::log2e_v<>),
294  stride_q,
295  stride_k,
296  stride_v,
297  stride_o,
298  nhead_stride_q,
299  nhead_stride_k,
300  nhead_stride_v,
301  nhead_stride_o}, // args for common karg
302  {}, // placeholder for mask
303  {}, // placeholder for lse
304  {}, // placeholder for logits_soft_cap
305  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
306  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
307  reinterpret_cast<const int32_t*>(seqlen_q_ptr),
308  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
309 
310  if constexpr(kHasMask)
311  {
312  kargs.window_size_left = window_size_left;
313  kargs.window_size_right = window_size_right;
314  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
315  kargs.remap_opt = remap_opt;
316  }
317  if constexpr(kStoreLSE)
318  {
319  kargs.lse_ptr = lse_ptr;
320  kargs.nhead_stride_lse = nhead_stride_lse;
321  }
322  if constexpr(kHasLogitsSoftCap)
323  {
324  kargs.init_logits_soft_cap(logits_soft_cap);
325  }
326 
327  kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
328  kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
329  return kargs;
330  }
331 
332  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
333  ck_tile::index_t nhead,
334  ck_tile::index_t max_seqlen_q,
335  ck_tile::index_t hdim_v)
336  {
337  if constexpr(kIsGroupMode)
338  {
339  return dim3(nhead,
340  batch_size,
341  ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
342  ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1));
343  }
344  else
345  {
346  return dim3(nhead,
347  ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
348  ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
349  batch_size);
350  }
351  }
352 
353  CK_TILE_DEVICE static constexpr auto
354  RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
355  {
356  if(remap_option < 1)
357  {
358  return make_tuple(static_cast<int32_t>(gridDim.x - tg_idx - 1), tg_idy);
359  }
360 
361  int32_t remapped_tg_idx = tg_idx;
362  int32_t remapped_tg_idy = tg_idy;
363 
364  if(remap_option == 2)
365  { // special remapping
366  int32_t tmp0 = (remapped_tg_idy & 0x7) * gridDim.x + remapped_tg_idx;
367  int32_t tmp1 = tmp0 & 0x7;
368 
369  remapped_tg_idx = tmp0 >> 3;
370  remapped_tg_idy = (remapped_tg_idy & 0xfffffff8) + tmp1;
371  }
372  else
373  { // normal remapping
374  int32_t cus_per_xdim_per_xcc = gridDim.x >> 3;
375  int32_t tgs_cu_id = remapped_tg_idx >> 3;
376 
377  if(tgs_cu_id < cus_per_xdim_per_xcc)
378  {
379  int32_t tgs_xcc_id = remapped_tg_idx & 0x7;
380  int32_t new_tg_idx = tgs_xcc_id * cus_per_xdim_per_xcc + tgs_cu_id;
381 
382  remapped_tg_idx = new_tg_idx;
383  }
384  }
385 
386  return make_tuple(remapped_tg_idx, remapped_tg_idy);
387  }
388 
389  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs&)
390  {
391  using namespace ck_tile;
392 
393  // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
394  // FmhaPipeline::kN1);
395 
396  // assume that num_tile_n1 is always 1
397  if constexpr(kIsGroupMode)
398  {
399  const index_t i_nhead = blockIdx.x;
400  const index_t i_batch = blockIdx.y;
401  const index_t i_block = blockIdx.z;
402 
403  if constexpr(kHasMask)
404  {
405  return ck_tile::make_tuple(gridDim.z - 1 - i_block, 0, i_nhead, i_batch);
406  }
407  else
408  {
409  return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
410  }
411  }
412  else
413  {
414  const index_t i_nhead = blockIdx.x;
415  const index_t i_block = blockIdx.y;
416  const index_t i_batch = blockIdx.z;
417 
418  if constexpr(kHasMask)
419  {
420  return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
421  }
422  else
423  {
424  return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
425  }
426  }
427  }
428 
429  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
430 
// GetSmemSize(): its signature line (431) is elided from this listing; per the page index
// it is a static constexpr CK_TILE_HOST_DEVICE function returning ck_tile::index_t.
432  {
// LDS bytes required: the larger of the pipeline's and the epilogue's requirement.
// operator() allocates a single __shared__ buffer of this size.
433  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
434  }
435 
// Device entry point. Each workgroup produces one kM0-row tile of O for one (batch, head).
// Steps:
//   - resolve per-batch memory offsets and logical sequence lengths;
//   - build the Q/K/V/LSE/O tile windows;
//   - run FmhaPipeline;
//   - hand the accumulator tile to EpiloguePipeline.
// kargs is taken by value because seqlen_q / seqlen_k are overwritten per batch below.
// NOTE: several source lines (alignment, window-length and pad arguments) are elided from
// this rendered listing.
436  CK_TILE_DEVICE void operator()(Kargs kargs) const
437  {
438  using namespace ck_tile;
439 
440  // allocate LDS: one buffer of GetSmemSize() bytes, handed to FmhaPipeline below
441  __shared__ char smem_ptr[GetSmemSize()];
442 
443  // divide problem
444  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
445 
// Row / column origin of this tile. i_tile_n is always 0 (GetTileIndex assumes one kN1 tile).
// NOTE(review): amd_wave_read_first_lane presumably makes these offsets wave-uniform — confirm.
446  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
447  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
448 
// Element offsets of this batch inside each tensor.
449  long_index_t batch_offset_q = 0;
450  long_index_t batch_offset_k = 0;
451  long_index_t batch_offset_v = 0;
452  long_index_t batch_offset_lse = 0;
453  long_index_t batch_offset_o = 0;
454 
455  if constexpr(kIsGroupMode)
456  {
457  // Use seqstart_q_ptr and seqstart_k_ptr for physical starts
458  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
459  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
460 
461  batch_offset_q = query_start * kargs.stride_q;
462  batch_offset_k = key_start * kargs.stride_k;
463  batch_offset_v = key_start * kargs.stride_v;
464 
465  if constexpr(kStoreLSE)
466  {
467  // LSE layout is [nhead, total_seqlen], index by unpadded start
// NOTE(review): query_start comes from seqstart_q_ptr, described above as the physical
// start; confirm the LSE layout when sequences are physically padded.
468  batch_offset_lse = query_start;
469  }
470  batch_offset_o = query_start * kargs.stride_o;
471 
472  // real logical lengths (exclude PAD)
473  // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
474  if(kargs.seqlen_q_ptr != nullptr)
475  {
476  kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
477  }
478  else if(kargs.cu_seqlen_q_ptr != nullptr)
479  {
480  kargs.seqlen_q =
481  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
482  }
483  else
484  {
485  kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
486  }
487  // The number of required m-tiles differs per group, so blocks past this group's
488  // seqlen_q exit early.
489  if(kargs.seqlen_q <= i_m0)
490  {
491  return;
492  }
493 
// Same priority for the key length: seqlen_k_ptr > cu_seqlen_k_ptr > seqstart_k_ptr.
494  if(kargs.seqlen_k_ptr != nullptr)
495  {
496  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
497  }
498  else if(kargs.cu_seqlen_k_ptr != nullptr)
499  {
500  kargs.seqlen_k =
501  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
502  }
503  else
504  {
505  kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
506  }
507  }
508  else
509  {
// Batch mode: fixed per-batch strides.
510  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
511  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
512  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
513  if constexpr(kStoreLSE)
514  {
515  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
516  }
517  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
518 
519  // If cumulative seqlen pointers are provided, override per-batch effective lengths
520  if(kargs.cu_seqlen_q_ptr != nullptr)
521  {
522  kargs.seqlen_q =
523  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
524  }
525  if(kargs.cu_seqlen_k_ptr != nullptr)
526  {
527  kargs.seqlen_k =
528  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
529  }
// NOTE(review): unlike group mode there is no early return when seqlen_q <= i_m0, so such
// blocks still run the pipeline on a window that starts beyond seqlen_q.
530  }
531 
532  // for simplicity, batch stride we just modify the pointer
// Apply head and batch offsets directly to the base pointers. K and V use the shared KV head
// i_nhead / nhead_ratio_qk (MQA/GQA).
533  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
534  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
535  batch_offset_q;
536  const KDataType* k_ptr =
537  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
538  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
539  batch_offset_k;
540  const VDataType* v_ptr =
541  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
542  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
543  batch_offset_v;
544  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
545  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
546  batch_offset_o;
547 
548  // Q/K/V DRAM and DRAM window
// Row-major global-memory views of this (batch, head) slice (row stride stride_*, unit column
// stride), wrapped by pad_tensor_view; its window-length / pad arguments are elided here.
549  const auto q_dram = [&]() {
550  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
551  q_ptr,
552  make_tuple(kargs.seqlen_q, kargs.hdim_q),
553  make_tuple(kargs.stride_q, 1),
555  number<1>{});
556 
557  return pad_tensor_view(
558  q_dram_naive,
561  }();
562  const auto k_dram = [&]() {
563  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
564  k_ptr,
565  make_tuple(kargs.seqlen_k, kargs.hdim_q),
566  make_tuple(kargs.stride_k, 1),
568  number<1>{});
569 
570  return pad_tensor_view(
571  k_dram_naive,
574  }();
575  const auto v_dram = [&]() {
576  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
577  v_ptr,
578  make_tuple(kargs.seqlen_k, kargs.hdim_v),
579  make_tuple(kargs.stride_v, 1),
581  number<1>{});
582 
583  return pad_tensor_view(
584  v_dram_naive,
587  }();
588 
// Q window starts at row i_m0; V window starts at column i_n1.
589  auto q_dram_window = make_tile_window(
590  q_dram,
592  {i_m0, 0});
593 
594  auto k_dram_window = make_tile_window(
596 
597  auto v_dram_window =
598  make_tile_window(v_dram,
600  {0, i_n1});
601 
602  // lse
// One-dimensional window of kM0 LSE values starting at i_m0; a null window when !kStoreLSE.
603  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
604  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
605  if constexpr(kStoreLSE)
606  {
607  LSEDataType* lse_ptr =
608  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
609  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
610 
611  const auto lse_dram = [&]() {
612  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
613  lse_ptr,
614  make_tuple(kargs.seqlen_q),
615  make_tuple(1),
616  number<1>{},
617  number<1>{});
618 
619  return pad_tensor_view(
620  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
621  }();
622 
623  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
624  }
625  else
626  {
627  return make_null_tile_window(lse_dram_window_lengths);
628  }
629  }();
630 
// Mask built from the left/right window sizes when masking is enabled; otherwise FmhaMask
// is constructed from seqlen_q / seqlen_k only.
631  FmhaMask mask = [&]() {
632  if constexpr(kHasMask)
633  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
634  kargs.window_size_left,
635  kargs.window_size_right,
636  kargs.seqlen_q,
637  kargs.seqlen_k,
639  else
640  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
641  }();
642 
// Attention-variant parameters: soft-cap parameters (type name on elided line 647) when
// kHasLogitsSoftCap, otherwise StandardAttentionParams.
643  AttentionVariant variant;
644  const auto variant_params = [&] {
645  if constexpr(kHasLogitsSoftCap)
646  {
648  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
649  }
650  else
651  {
652  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
653  }
654  }();
655 
656  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
657 
// Run the attention pipeline for this tile; it returns the O accumulator tile.
658  auto o_acc_tile = [&]() {
659  return FmhaPipeline{}(q_dram_window,
660  k_dram_window,
661  v_dram_window,
662  lse_dram_window,
663  mask,
664  kargs.scale_s,
665  variant,
666  variant_params,
667  block_indices,
668  smem_ptr);
669  }();
670 
671  // O DRAM and O DRAM window
672  auto o_dram = [&]() {
673  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
674  o_ptr,
675  make_tuple(kargs.seqlen_q, kargs.hdim_v),
676  make_tuple(kargs.stride_o, 1),
678  number<1>{});
679 
680  return pad_tensor_view(
681  o_dram_naive,
684  }();
685 
686  auto o_dram_window =
687  make_tile_window(o_dram,
689  {i_m0, i_n1});
690 
// Store the result through the epilogue; it receives no LDS pointer (nullptr).
691  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
692  }
693 };
694 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:145
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: fmha_fwd_v3_kernel.hpp:158
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_v3_kernel.hpp:160
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_v3_kernel.hpp:161
ck_tile::index_t batch_idx
Definition: fmha_fwd_v3_kernel.hpp:159
Definition: fmha_fwd_v3_kernel.hpp:126
const ck_tile::index_t * cu_seqlen_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:135
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_v3_kernel.hpp:130
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_v3_kernel.hpp:127
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_v3_kernel.hpp:128
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_v3_kernel.hpp:129
const ck_tile::index_t * cu_seqlen_q_ptr
Definition: fmha_fwd_v3_kernel.hpp:134
Definition: fmha_fwd_v3_kernel.hpp:56
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_v3_kernel.hpp:80
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_v3_kernel.hpp:78
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_v3_kernel.hpp:70
ck_tile::index_t stride_o
Definition: fmha_fwd_v3_kernel.hpp:76
ck_tile::index_t seqlen_k
Definition: fmha_fwd_v3_kernel.hpp:63
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_v3_kernel.hpp:81
ck_tile::index_t hdim_q
Definition: fmha_fwd_v3_kernel.hpp:64
ck_tile::index_t seqlen_q
Definition: fmha_fwd_v3_kernel.hpp:62
ck_tile::index_t stride_v
Definition: fmha_fwd_v3_kernel.hpp:75
const void * q_ptr
Definition: fmha_fwd_v3_kernel.hpp:57
float scale_s
Definition: fmha_fwd_v3_kernel.hpp:71
const void * v_ptr
Definition: fmha_fwd_v3_kernel.hpp:59
void * o_ptr
Definition: fmha_fwd_v3_kernel.hpp:60
ck_tile::index_t stride_q
Definition: fmha_fwd_v3_kernel.hpp:73
ck_tile::index_t num_head_q
Definition: fmha_fwd_v3_kernel.hpp:67
const void * k_ptr
Definition: fmha_fwd_v3_kernel.hpp:58
ck_tile::index_t hdim_v
Definition: fmha_fwd_v3_kernel.hpp:65
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_v3_kernel.hpp:79
ck_tile::index_t stride_k
Definition: fmha_fwd_v3_kernel.hpp:74
Definition: fmha_fwd_v3_kernel.hpp:93
void * lse_ptr
Definition: fmha_fwd_v3_kernel.hpp:94
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_v3_kernel.hpp:95
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_v3_kernel.hpp:96
Definition: fmha_fwd_v3_kernel.hpp:49
Definition: fmha_fwd_v3_kernel.hpp:143
const int32_t * cu_seqlen_q_ptr
Definition: fmha_fwd_v3_kernel.hpp:151
const int32_t * cu_seqlen_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:152
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:147
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:145
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_v3_kernel.hpp:144
const int32_t * seqlen_q_ptr
Definition: fmha_fwd_v3_kernel.hpp:146
Definition: fmha_fwd_v3_kernel.hpp:100
float logits_soft_cap_rcp
Definition: fmha_fwd_v3_kernel.hpp:118
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_v3_kernel.hpp:103
float logits_soft_cap
Definition: fmha_fwd_v3_kernel.hpp:117
Definition: fmha_fwd_v3_kernel.hpp:85
ck_tile::index_t window_size_left
Definition: fmha_fwd_v3_kernel.hpp:87
ck_tile::index_t remap_opt
Definition: fmha_fwd_v3_kernel.hpp:89
ck_tile::index_t window_size_right
Definition: fmha_fwd_v3_kernel.hpp:87
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_v3_kernel.hpp:88
Definition: fmha_fwd_v3_kernel.hpp:20
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_v3_kernel.hpp:35
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_v3_kernel.hpp:27
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_v3_kernel.hpp:436
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition: fmha_fwd_v3_kernel.hpp:166
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_v3_kernel.hpp:36
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_v3_kernel.hpp:42
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_v3_kernel.hpp:431
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_v3_kernel.hpp:23
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_v3_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_v3_kernel.hpp:28
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_v3_kernel.hpp:39
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_v3_kernel.hpp:38
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_fwd_v3_kernel.hpp:429
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_v3_kernel.hpp:21
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition: fmha_fwd_v3_kernel.hpp:252
static constexpr bool kHasMask
Definition: fmha_fwd_v3_kernel.hpp:44
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_v3_kernel.hpp:155
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &)
Definition: fmha_fwd_v3_kernel.hpp:389
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_v3_kernel.hpp:22
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_v3_kernel.hpp:37
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_v3_kernel.hpp:29
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v)
Definition: fmha_fwd_v3_kernel.hpp:332
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_v3_kernel.hpp:24
static constexpr bool kStoreLSE
Definition: fmha_fwd_v3_kernel.hpp:40
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_v3_kernel.hpp:32
static constexpr CK_TILE_DEVICE auto RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
Definition: fmha_fwd_v3_kernel.hpp:354
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_v3_kernel.hpp:30
static constexpr bool kIsGroupMode
Definition: fmha_fwd_v3_kernel.hpp:34
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_v3_kernel.hpp:43
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49