/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp Source File
fmha_fwd_v3_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
9 
10 #include <type_traits>
11 #include <utility>
12 
13 namespace ck_tile {
14 
15 template <typename FmhaPipeline_, typename EpiloguePipeline_>
// NOTE(review): this is a documentation extraction; the struct/class declaration
// line (original line 16/17, the kernel class name) and several member
// declaration lines below are missing from this view — compare against the
// repository source before editing.
17 {
// Launch configuration constants forwarded from the pipeline policy.
20  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
21  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
22  static_assert(kBlockPerCu > 0);
23 
30 
// Compile-time feature switches, all sourced from the pipeline so kernel and
// pipeline cannot disagree.
31  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
32  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
33  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
34  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
35  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
36  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
37 
39  static constexpr bool kHasMask = FmhaMask::IsMasking;
40 
41  template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
42  // arg
// Empty base used as a placeholder when a feature (mask / LSE) is disabled;
// distinct I values keep the two empty bases from being the same type.
44  {
45  };
46 
47  // kargs use aggregate initializer, so no constructor will be provided
48  // use inheritance to minimize karg size
49  // user need to use MakeKargs() function to create kargs.
// Common kernel arguments shared by batch and group mode (declaration line
// missing in this extraction; see doc index: FmhaFwdCommonKargs, line 51).
51  {
52  const void* q_ptr;
53  const void* k_ptr;
54  const void* v_ptr;
55  void* o_ptr;
56 
61 
63  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
64  // if this param is larger than 1, indicate MQA/GQA case
// NOTE(review): scale_s is stored pre-multiplied by log2(e) in MakeKargs.
66  float scale_s;
67 
72 
77  };
78 
// Mask-related kargs (window sizes, mask type, remap option) — member lines
// partially missing from this extraction.
80  {
81  // ck_tile::index_t window_size_left, window_size_right;
85  };
86 
// LSE (log-sum-exp) output kargs, only mixed in when kStoreLSE is set.
88  {
89  void* lse_ptr = nullptr;
92  };
93 
// Batch-mode kargs: fixed per-batch strides; optional cu_seqlen pointers can
// shrink the effective lengths per batch (declaration line missing here).
96  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
97  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
98  {
103 
104  // Optional cumulative sequence length pointers for batch mode
105  // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
106  const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
107  const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
108  };
109 
// Group-mode (varlen) kargs: sequences are packed; per-batch extents come from
// seqstart/seqlen device arrays (declaration line missing here).
112  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
113  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
114  {
118 
119  // Optional cumulative padded sequence starts (including PAD tokens)
120  // Used solely to compute memory offsets when sequences are physically padded.
121  const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1]
122  const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1]
123  };
124 
// Select the karg flavor at compile time from the pipeline's mode.
125  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
126 
// Host-side factory for batch-mode kernel arguments. Enabled only when the
// kernel is compiled in batch mode (Cond defaults to !kIsGroupMode).
// NOTE(review): the aggregate initializer below must match the member
// declaration order of FmhaFwdCommonKargs / FmhaFwdBatchModeKargs, some of
// whose declaration lines are missing from this extraction — do not reorder.
127  template <bool Cond = !kIsGroupMode>
128  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
129  MakeKargs(const void* q_ptr,
130  const void* k_ptr,
131  const void* v_ptr,
132  void* lse_ptr,
133  void* o_ptr,
134  ck_tile::index_t seqlen_q,
135  ck_tile::index_t seqlen_k,
136  ck_tile::index_t hdim_q,
137  ck_tile::index_t hdim_v,
138  ck_tile::index_t num_head_q,
139  ck_tile::index_t nhead_ratio_qk,
140  float scale_s,
141  ck_tile::index_t stride_q,
142  ck_tile::index_t stride_k,
143  ck_tile::index_t stride_v,
144  ck_tile::index_t stride_o,
145  ck_tile::index_t nhead_stride_q,
146  ck_tile::index_t nhead_stride_k,
147  ck_tile::index_t nhead_stride_v,
148  ck_tile::index_t nhead_stride_lse,
149  ck_tile::index_t nhead_stride_o,
150  ck_tile::index_t batch_stride_q,
151  ck_tile::index_t batch_stride_k,
152  ck_tile::index_t batch_stride_v,
153  ck_tile::index_t batch_stride_lse,
154  ck_tile::index_t batch_stride_o,
155  ck_tile::index_t window_size_left,
156  ck_tile::index_t window_size_right,
157  ck_tile::index_t mask_type,
158  ck_tile::index_t remap_opt,
159  const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
160  const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
161  {
162  Kargs kargs{{q_ptr,
163  k_ptr,
164  v_ptr,
165  o_ptr,
166  seqlen_q,
167  seqlen_k,
168  hdim_q,
169  hdim_v,
170  num_head_q,
171  nhead_ratio_qk,
// scale is folded with log2(e) once on the host — presumably so the device
// softmax can use exp2 instead of exp; confirm against FmhaPipeline.
172  static_cast<float>(scale_s * ck_tile::log2e_v<>),
173  stride_q,
174  stride_k,
175  stride_v,
176  stride_o,
177  nhead_stride_q,
178  nhead_stride_k,
179  nhead_stride_v,
180  nhead_stride_o}, // args for common karg
181  {}, // placeholder for mask
182  {}, // placeholder for lse
183  batch_stride_q,
184  batch_stride_k,
185  batch_stride_v,
186  batch_stride_o};
187 
// Mask/LSE fields exist only when the corresponding feature is compiled in,
// so they are assigned under if constexpr rather than in the aggregate init.
188  if constexpr(kHasMask)
189  {
190  kargs.window_size_left = window_size_left;
191  kargs.window_size_right = window_size_right;
192  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
193  kargs.remap_opt = remap_opt;
194  }
195  if constexpr(kStoreLSE)
196  {
197  kargs.lse_ptr = lse_ptr;
198  kargs.nhead_stride_lse = nhead_stride_lse;
199  kargs.batch_stride_lse = batch_stride_lse;
200  }
201 
// Optional [batch+1] cumulative-length arrays; when non-null the kernel
// overrides seqlen_q / seqlen_k per batch (see operator()).
202  kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
203  kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
204  return kargs;
205  }
206 
// Host-side factory for group-mode (varlen) kernel arguments. Sequence
// lengths are not known per-batch at launch: seqlen_q / seqlen_k are set to -1
// and resolved on device from the seqstart / seqlen pointer arrays.
// NOTE(review): aggregate-initializer order must match member declarations
// partially missing from this extraction — do not reorder.
207  template <bool Cond = kIsGroupMode>
208  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
209  MakeKargs(const void* q_ptr,
210  const void* k_ptr,
211  const void* v_ptr,
212  void* lse_ptr,
213  void* o_ptr,
214  const void* seqstart_q_ptr,
215  const void* seqstart_k_ptr,
216  const void* seqlen_k_ptr,
217  ck_tile::index_t hdim_q,
218  ck_tile::index_t hdim_v,
219  ck_tile::index_t num_head_q,
220  ck_tile::index_t nhead_ratio_qk,
221  float scale_s,
222  ck_tile::index_t stride_q,
223  ck_tile::index_t stride_k,
224  ck_tile::index_t stride_v,
225  ck_tile::index_t stride_o,
226  ck_tile::index_t nhead_stride_q,
227  ck_tile::index_t nhead_stride_k,
228  ck_tile::index_t nhead_stride_v,
229  ck_tile::index_t nhead_stride_lse,
230  ck_tile::index_t nhead_stride_o,
231  ck_tile::index_t window_size_left,
232  ck_tile::index_t window_size_right,
233  ck_tile::index_t mask_type,
234  ck_tile::index_t remap_opt,
235  const void* seqstart_padded_q_ptr = nullptr,
236  const void* seqstart_padded_k_ptr = nullptr)
237  {
238  Kargs kargs{{q_ptr,
239  k_ptr,
240  v_ptr,
241  o_ptr,
242  -1, // seqlen will be updated by another pointer
243  -1, //
244  hdim_q,
245  hdim_v,
246  num_head_q,
247  nhead_ratio_qk,
// scale pre-multiplied by log2(e) on the host, same as batch mode.
248  static_cast<float>(scale_s * ck_tile::log2e_v<>),
249  stride_q,
250  stride_k,
251  stride_v,
252  stride_o,
253  nhead_stride_q,
254  nhead_stride_k,
255  nhead_stride_v,
256  nhead_stride_o}, // args for common karg
257  {}, // placeholder for mask
258  {}, // placeholder for lse
259  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
260  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
261  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
262 
263  if constexpr(kHasMask)
264  {
265  kargs.window_size_left = window_size_left;
266  kargs.window_size_right = window_size_right;
267  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
268  kargs.remap_opt = remap_opt;
269  }
270  if constexpr(kStoreLSE)
271  {
// No batch_stride_lse in group mode: LSE is indexed via the unpadded
// seqstart offsets on device instead of a fixed per-batch stride.
272  kargs.lse_ptr = lse_ptr;
273  kargs.nhead_stride_lse = nhead_stride_lse;
274  }
275 
// Optional [batch+1] padded starts — used only for memory offsets when the
// packed sequences physically contain PAD tokens.
276  kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
277  kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
278  return kargs;
279  }
280 
281  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
282  ck_tile::index_t nhead_,
283  ck_tile::index_t seqlen_q_,
284  ck_tile::index_t hdim_v_)
285  {
286  // TODO: this may need tuning
287  if constexpr(kHasMask)
288  {
289  return dim3(nhead_,
290  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
291  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
292  batch_size_);
293  }
294  else
295  {
296  return dim3(nhead_,
297  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
298  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
299  batch_size_);
300  }
301  }
302 
303  CK_TILE_DEVICE static constexpr auto
304  RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
305  {
306  if(remap_option < 1)
307  {
308  return make_tuple(static_cast<int32_t>(gridDim.x - tg_idx - 1), tg_idy);
309  }
310 
311  int32_t remapped_tg_idx = tg_idx;
312  int32_t remapped_tg_idy = tg_idy;
313 
314  if(remap_option == 2)
315  { // special remapping
316  int32_t tmp0 = (remapped_tg_idy & 0x7) * gridDim.x + remapped_tg_idx;
317  int32_t tmp1 = tmp0 & 0x7;
318 
319  remapped_tg_idx = tmp0 >> 3;
320  remapped_tg_idy = (remapped_tg_idy & 0xfffffff8) + tmp1;
321  }
322  else
323  { // normal remapping
324  int32_t cus_per_xdim_per_xcc = gridDim.x >> 3;
325  int32_t tgs_cu_id = remapped_tg_idx >> 3;
326 
327  if(tgs_cu_id < cus_per_xdim_per_xcc)
328  {
329  int32_t tgs_xcc_id = remapped_tg_idx & 0x7;
330  int32_t new_tg_idx = tgs_xcc_id * cus_per_xdim_per_xcc + tgs_cu_id;
331 
332  remapped_tg_idx = new_tg_idx;
333  }
334  }
335 
336  return make_tuple(remapped_tg_idx, remapped_tg_idy);
337  }
338 
339  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs&)
340  {
341  using namespace ck_tile;
342 
343  // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
344  // FmhaPipeline::kN1);
345 
346  // assume that num_tile_n1 is always 1
347  if constexpr(kHasMask)
348  {
349  const index_t i_nhead = blockIdx.x;
350  const index_t i_block = blockIdx.y;
351  const index_t i_batch = blockIdx.z;
352 
353  return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
354  }
355  else
356  {
357  const index_t i_nhead = blockIdx.x;
358  const index_t i_block = blockIdx.y;
359  const index_t i_batch = blockIdx.z;
360 
361  return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
362  }
363  }
364 
// 1D thread-block size used at launch.
365  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
366 
// Shared-memory (LDS) requirement: main pipeline and epilogue reuse the same
// buffer, so the allocation is the max of the two.
// NOTE(review): the GetSmemSize() signature line (original line 367) is
// missing from this documentation extraction.
368  {
369  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
370  }
371 
// Device entry point: computes one [kM0 x kN1] output tile of
// O = softmax(scale * Q*K^T, optionally masked) * V for one (batch, head).
// NOTE(review): this is a documentation extraction — several argument lines
// inside the tensor-view / tile-window constructions below (originals 482,
// 487-488, 495, 500-501, 508, 513-514, 519, 523, 527, 566, 587, 592-593, 598)
// are missing; treat the code here as incomplete and diff against the
// repository source before editing.
372  CK_TILE_DEVICE void operator()(Kargs kargs) const
373  {
374  using namespace ck_tile;
375 
376  // allocate LDS
377  __shared__ char smem_ptr[GetSmemSize()];
378 
379  // divide problem
380  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
381 
// Broadcast the tile origin from lane 0 so the offsets are uniform per wave.
382  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
383  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
384 
// Per-batch base offsets into each tensor, resolved differently per mode.
385  long_index_t batch_offset_q = 0;
386  long_index_t batch_offset_k = 0;
387  long_index_t batch_offset_v = 0;
388  long_index_t batch_offset_lse = 0;
389  long_index_t batch_offset_o = 0;
390 
391  if constexpr(kIsGroupMode)
392  {
393  // get starting offset for each batch
394  const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
395  const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
396 
// Physical offsets use the padded starts when provided; logical lengths
// below always come from the unpadded seqstart deltas.
397  const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
398  ? kargs.seqstart_padded_q_ptr[i_batch]
399  : query_start_unpadded;
400  const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
401  ? kargs.seqstart_padded_k_ptr[i_batch]
402  : key_start_unpadded;
403 
404  batch_offset_q = query_start_padded * kargs.stride_q;
405  batch_offset_k = key_start_padded * kargs.stride_k;
406  batch_offset_v = key_start_padded * kargs.stride_v;
407 
408  if constexpr(kStoreLSE)
409  {
410  // LSE layout is [nhead, total_seqlen], index by unpadded start
411  batch_offset_lse = query_start_unpadded;
412  }
413  batch_offset_o = query_start_padded * kargs.stride_o;
414 
415  // get real # queries & # keys under group mode
416  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
417  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
418 
419  // # of required blocks is different in each groups, terminate unnecessary blocks
420  // earlier
421  if(kargs.seqlen_q <= i_m0)
422  {
423  return;
424  }
425 
// seqlen_k: explicit per-batch array takes precedence over seqstart deltas.
426  if(kargs.seqlen_k_ptr != nullptr)
427  {
428  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
429  }
430  else
431  {
432  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
433  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
434  }
435  }
436  else
437  {
// Batch mode: plain strided offsets per batch index.
438  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
439  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
440  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
441  if constexpr(kStoreLSE)
442  {
443  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
444  }
445  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
446 
447  // If cumulative seqlen pointers are provided, override per-batch effective lengths
448  if(kargs.cu_seqlen_q_ptr != nullptr)
449  {
450  kargs.seqlen_q =
451  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
452  }
453  if(kargs.cu_seqlen_kv_ptr != nullptr)
454  {
455  kargs.seqlen_k =
456  kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
457  }
458  }
459 
460  // for simplicity, batch stride we just modify the pointer
// K/V heads are shared across nhead_ratio_qk query heads (MQA/GQA), hence
// the divided head index for K and V.
461  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
462  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
463  batch_offset_q;
464  const KDataType* k_ptr =
465  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
466  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
467  batch_offset_k;
468  const VDataType* v_ptr =
469  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
470  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
471  batch_offset_v;
472  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
473  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
474  batch_offset_o;
475 
476  // Q/K/V DRAM and DRAM window
// Row-major [seqlen, hdim] views over global memory, padded to tile bounds
// per the kPad* switches (padding arguments elided by the extraction).
477  const auto q_dram = [&]() {
478  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
479  q_ptr,
480  make_tuple(kargs.seqlen_q, kargs.hdim_q),
481  make_tuple(kargs.stride_q, 1),
483  number<1>{});
484 
485  return pad_tensor_view(
486  q_dram_naive,
489  }();
490  const auto k_dram = [&]() {
491  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
492  k_ptr,
493  make_tuple(kargs.seqlen_k, kargs.hdim_q),
494  make_tuple(kargs.stride_k, 1),
496  number<1>{});
497 
498  return pad_tensor_view(
499  k_dram_naive,
502  }();
503  const auto v_dram = [&]() {
504  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
505  v_ptr,
506  make_tuple(kargs.seqlen_v, kargs.hdim_v),
507  make_tuple(kargs.stride_v, 1),
509  number<1>{});
510 
511  return pad_tensor_view(
512  v_dram_naive,
515  }();
516 
// Tile windows: Q starts at this block's M tile, K at the top of the
// sequence, V at this block's N1 column.
517  auto q_dram_window = make_tile_window(
518  q_dram,
520  {i_m0, 0});
521 
522  auto k_dram_window = make_tile_window(
524 
525  auto v_dram_window =
526  make_tile_window(v_dram,
528  {0, i_n1});
529 
530  // lse
// Window over the LSE output row, or a null window when LSE is disabled.
531  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
532  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
533  if constexpr(kStoreLSE)
534  {
535  LSEDataType* lse_ptr =
536  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
537  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
538 
539  const auto lse_dram = [&]() {
540  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
541  lse_ptr,
542  make_tuple(kargs.seqlen_q),
543  make_tuple(1),
544  number<1>{},
545  number<1>{});
546 
547  return pad_tensor_view(
548  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
549  }();
550 
551  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
552  }
553  else
554  {
555  return make_null_tile_window(lse_dram_window_lengths);
556  }
557  }();
558 
// Mask object: sliding-window construction when masking is compiled in,
// otherwise the trivial full-attention mask.
559  FmhaMask mask = [&]() {
560  if constexpr(kHasMask)
561  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
562  kargs.window_size_left,
563  kargs.window_size_right,
564  kargs.seqlen_q,
565  kargs.seqlen_k,
567  else
568  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
569  }();
570 
// Run the main attention pipeline; result is the accumulator tile in
// registers (scale_s already includes the log2(e) factor from MakeKargs).
571  auto o_acc_tile = [&]() {
572  return FmhaPipeline{}(q_dram_window,
573  k_dram_window,
574  v_dram_window,
575  lse_dram_window,
576  mask,
577  kargs.scale_s,
578  smem_ptr);
579  }();
580 
581  // O DRAM and O DRAM window
582  auto o_dram = [&]() {
583  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
584  o_ptr,
585  make_tuple(kargs.seqlen_q, kargs.hdim_v),
586  make_tuple(kargs.stride_o, 1),
588  number<1>{});
589 
590  return pad_tensor_view(
591  o_dram_naive,
594  }();
595 
596  auto o_dram_window =
597  make_tile_window(o_dram,
599  {i_m0, i_n1});
600 
// Epilogue converts/stores the accumulator tile to global memory.
601  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
602  }
603 };
604 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:2834
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: fmha_fwd_v3_kernel.hpp:98
const ck_tile::index_t * cu_seqlen_kv_ptr
Definition: fmha_fwd_v3_kernel.hpp:107
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_v3_kernel.hpp:102
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_v3_kernel.hpp:99
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_v3_kernel.hpp:100
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_v3_kernel.hpp:101
const ck_tile::index_t * cu_seqlen_q_ptr
Definition: fmha_fwd_v3_kernel.hpp:106
Definition: fmha_fwd_v3_kernel.hpp:51
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_v3_kernel.hpp:75
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_v3_kernel.hpp:73
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_v3_kernel.hpp:65
ck_tile::index_t stride_o
Definition: fmha_fwd_v3_kernel.hpp:71
ck_tile::index_t seqlen_k
Definition: fmha_fwd_v3_kernel.hpp:58
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_v3_kernel.hpp:76
ck_tile::index_t hdim_q
Definition: fmha_fwd_v3_kernel.hpp:59
ck_tile::index_t seqlen_q
Definition: fmha_fwd_v3_kernel.hpp:57
ck_tile::index_t stride_v
Definition: fmha_fwd_v3_kernel.hpp:70
const void * q_ptr
Definition: fmha_fwd_v3_kernel.hpp:52
float scale_s
Definition: fmha_fwd_v3_kernel.hpp:66
const void * v_ptr
Definition: fmha_fwd_v3_kernel.hpp:54
void * o_ptr
Definition: fmha_fwd_v3_kernel.hpp:55
ck_tile::index_t stride_q
Definition: fmha_fwd_v3_kernel.hpp:68
ck_tile::index_t num_head_q
Definition: fmha_fwd_v3_kernel.hpp:62
const void * k_ptr
Definition: fmha_fwd_v3_kernel.hpp:53
ck_tile::index_t hdim_v
Definition: fmha_fwd_v3_kernel.hpp:60
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_v3_kernel.hpp:74
ck_tile::index_t stride_k
Definition: fmha_fwd_v3_kernel.hpp:69
Definition: fmha_fwd_v3_kernel.hpp:88
void * lse_ptr
Definition: fmha_fwd_v3_kernel.hpp:89
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_v3_kernel.hpp:90
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_v3_kernel.hpp:91
Definition: fmha_fwd_v3_kernel.hpp:44
Definition: fmha_fwd_v3_kernel.hpp:114
const int32_t * seqstart_padded_q_ptr
Definition: fmha_fwd_v3_kernel.hpp:121
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:117
const int32_t * seqstart_padded_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:122
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:116
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_v3_kernel.hpp:115
Definition: fmha_fwd_v3_kernel.hpp:80
ck_tile::index_t window_size_left
Definition: fmha_fwd_v3_kernel.hpp:82
ck_tile::index_t remap_opt
Definition: fmha_fwd_v3_kernel.hpp:84
ck_tile::index_t window_size_right
Definition: fmha_fwd_v3_kernel.hpp:82
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_v3_kernel.hpp:83
Definition: fmha_fwd_v3_kernel.hpp:17
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_v3_kernel.hpp:32
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const void *seqstart_padded_q_ptr=nullptr, const void *seqstart_padded_k_ptr=nullptr)
Definition: fmha_fwd_v3_kernel.hpp:209
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_v3_kernel.hpp:24
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_v3_kernel.hpp:372
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_v3_kernel.hpp:33
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_v3_kernel.hpp:367
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_)
Definition: fmha_fwd_v3_kernel.hpp:281
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_v3_kernel.hpp:20
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_v3_kernel.hpp:28
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_v3_kernel.hpp:25
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_v3_kernel.hpp:35
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_fwd_v3_kernel.hpp:365
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_v3_kernel.hpp:18
static constexpr bool kHasMask
Definition: fmha_fwd_v3_kernel.hpp:39
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_v3_kernel.hpp:125
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &)
Definition: fmha_fwd_v3_kernel.hpp:339
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_v3_kernel.hpp:19
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_v3_kernel.hpp:34
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_v3_kernel.hpp:26
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const ck_tile::index_t *cu_seqlen_q_ptr=nullptr, const ck_tile::index_t *cu_seqlen_kv_ptr=nullptr)
Definition: fmha_fwd_v3_kernel.hpp:129
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_v3_kernel.hpp:21
static constexpr bool kStoreLSE
Definition: fmha_fwd_v3_kernel.hpp:36
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_v3_kernel.hpp:29
static constexpr CK_TILE_DEVICE auto RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
Definition: fmha_fwd_v3_kernel.hpp:304
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_v3_kernel.hpp:27
static constexpr bool kIsGroupMode
Definition: fmha_fwd_v3_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_v3_kernel.hpp:38
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49