Composable Kernel: include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp Source File

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

#include <type_traits>
#include <utility>

namespace ck_tile {

template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdV3Kernel
{
    using FmhaPipeline     = ck_tile::remove_cvref_t<FmhaPipeline_>;
    using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
    static constexpr ck_tile::index_t kBlockSize  = FmhaPipeline::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
    static_assert(kBlockPerCu > 0);

    using QDataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
    using KDataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
    using VDataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
    using LSEDataType  = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
    using ODataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
    using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;

    static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
    static constexpr bool kPadSeqLenQ  = FmhaPipeline::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK  = FmhaPipeline::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
    static constexpr bool kStoreLSE    = FmhaPipeline::kStoreLSE;

    using FmhaMask                 = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
    static constexpr bool kHasMask = FmhaMask::IsMasking;

    // to avoid the duplicated base class problem, introduce a template arg
    template <ck_tile::index_t I>
    struct FmhaFwdEmptyKargs
    {
    };
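    // Why the dummy template arg: the two conditional bases of the kargs structs below may
    // both resolve to an empty type, and inheriting the same class twice directly is
    // ill-formed in C++. FmhaFwdEmptyKargs<0> and FmhaFwdEmptyKargs<1> are distinct types
    // that still occupy no storage thanks to the empty-base optimization.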
    // kargs use aggregate initializer, so no constructor will be provided
    // use inheritance to minimize karg size
    // users need to use the MakeKargs() function to create kargs.
    struct FmhaFwdCommonKargs
    {
        const void* q_ptr;
        const void* k_ptr;
        const void* v_ptr;
        void* o_ptr;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t seqlen_k;
        ck_tile::index_t hdim_q;
        ck_tile::index_t hdim_v;

        ck_tile::index_t num_head_q;
        // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k;
        // if it is larger than 1, it indicates an MQA/GQA case
        ck_tile::index_t nhead_ratio_qk;
        float scale_s;

        ck_tile::index_t stride_q;
        ck_tile::index_t stride_k;
        ck_tile::index_t stride_v;
        ck_tile::index_t stride_o;

        ck_tile::index_t nhead_stride_q;
        ck_tile::index_t nhead_stride_k;
        ck_tile::index_t nhead_stride_v;
        ck_tile::index_t nhead_stride_o;
    };
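    // Example (illustrative): for GQA with nhead_q = 32 query heads sharing nhead_k = 8
    // KV heads, nhead_ratio_qk = 32 / 8 = 4 and query head i reads KV head
    // i / nhead_ratio_qk (see the k_ptr/v_ptr offsets in operator() below);
    // nhead_ratio_qk == 1 is plain MHA, nhead_k == 1 is MQA.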
    struct FmhaFwdMaskKargs
    {
        // ck_tile::index_t window_size_left, window_size_right;
        ck_tile::index_t window_size_left, window_size_right;
        ck_tile::GenericAttentionMaskEnum mask_type;
        ck_tile::index_t remap_opt;
    };

    struct FmhaFwdCommonLSEKargs
    {
        void* lse_ptr                     = nullptr;
        ck_tile::index_t nhead_stride_lse = 0;
        ck_tile::index_t batch_stride_lse = 0;
    };
    struct FmhaFwdBatchModeKargs
        : FmhaFwdCommonKargs,
          std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
          std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
    {
        ck_tile::index_t batch_stride_q;
        ck_tile::index_t batch_stride_k;
        ck_tile::index_t batch_stride_v;
        ck_tile::index_t batch_stride_o;

        // Optional cumulative sequence length pointers for batch mode.
        // If provided, they override seqlen_q / seqlen_k per batch to skip tail padding.
        const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
        const ck_tile::index_t* cu_seqlen_k_ptr = nullptr; // [batch+1]
    };
    struct FmhaFwdGroupModeKargs
        : FmhaFwdCommonKargs,
          std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
          std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
    {
        const int32_t* seqstart_q_ptr;
        const int32_t* seqstart_k_ptr;
        const int32_t* seqlen_q_ptr;
        const int32_t* seqlen_k_ptr;

        // Optional cumulative padded sequence starts (including PAD tokens)
        // Used solely to compute memory offsets when sequences are physically padded.
        const int32_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
        const int32_t* cu_seqlen_k_ptr = nullptr; // [batch+1]
    };
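    // Example (illustrative): with three sequences of lengths {100, 37, 64} packed
    // back-to-back, seqstart_q_ptr = {0, 100, 137, 201}; per-batch lengths are then either
    // read from seqlen_q_ptr or recovered as seqstart_q_ptr[b+1] - seqstart_q_ptr[b]
    // (see the priority chain in operator() below).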
    using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;

    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              float scale_s,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_q,
              ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v,
              ck_tile::index_t batch_stride_lse,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              ck_tile::index_t remap_opt,
              const void* cu_seqlen_q_ptr = nullptr,
              const void* cu_seqlen_k_ptr = nullptr)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     v_ptr,
                     o_ptr,
                     seqlen_q,
                     seqlen_k,
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
                     static_cast<float>(scale_s * ck_tile::log2e_v<>),
                     stride_q,
                     stride_k,
                     stride_v,
                     stride_o,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_v,
                     nhead_stride_o}, // args for common karg
                    {},               // placeholder for mask
                    {},               // placeholder for lse
                    batch_stride_q,
                    batch_stride_k,
                    batch_stride_v,
                    batch_stride_o};

        if constexpr(kHasMask)
        {
            kargs.window_size_left  = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type         = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
            kargs.remap_opt         = remap_opt;
        }
        if constexpr(kStoreLSE)
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
            kargs.batch_stride_lse = batch_stride_lse;
        }

        kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
        kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
        return kargs;
    }
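    // Note on the scale: the pipelines evaluate softmax with base-2 exponentials, and
    // exp(x) == exp2(x * log2(e)), so scale_s is pre-multiplied by log2e_v<> once on the
    // host. Example (illustrative): for hdim_q = 128 the usual scale 1/sqrt(128) ~= 0.0884
    // is stored as 0.0884 * log2(e) ~= 0.1275.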
    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              void* lse_ptr,
              void* o_ptr,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              const void* seqlen_q_ptr,
              const void* seqlen_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              float scale_s,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              ck_tile::index_t remap_opt,
              const void* cu_seqlen_q_ptr = nullptr,
              const void* cu_seqlen_k_ptr = nullptr)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     v_ptr,
                     o_ptr,
                     -1, // seqlen will be updated by another pointer
                     -1, //
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
                     static_cast<float>(scale_s * ck_tile::log2e_v<>),
                     stride_q,
                     stride_k,
                     stride_v,
                     stride_o,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_v,
                     nhead_stride_o}, // args for common karg
                    {},               // placeholder for mask
                    {},               // placeholder for lse
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                    reinterpret_cast<const int32_t*>(seqlen_q_ptr),
                    reinterpret_cast<const int32_t*>(seqlen_k_ptr)};

        if constexpr(kHasMask)
        {
            kargs.window_size_left  = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type         = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
            kargs.remap_opt         = remap_opt;
        }
        if constexpr(kStoreLSE)
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
        }

        kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
        kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
        return kargs;
    }
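    // Example (illustrative; host-side helper usage may differ): kargs are paired with
    // GridSize()/BlockSize() at launch. The dynamic LDS size is 0 because operator() below
    // allocates its LDS statically via __shared__.
    //   using Kernel = FmhaFwdV3Kernel<Pipeline, Epilogue>;
    //   auto kargs   = Kernel::MakeKargs(q, k, v, lse, o, /* ...sizes, strides, mask... */);
    //   launch_kernel(stream_config{},
    //                 make_kernel<Kernel::kBlockSize, Kernel::kBlockPerCu>(
    //                     Kernel{},
    //                     Kernel::GridSize(batch, nhead_q, max_seqlen_q, hdim_v),
    //                     Kernel::BlockSize(),
    //                     0,
    //                     kargs));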
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t max_seqlen_q,
                                                ck_tile::index_t hdim_v)
    {
        if constexpr(kIsGroupMode)
        {
            return dim3(nhead,
                        batch_size,
                        ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
                            ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1));
        }
        else
        {
            return dim3(nhead,
                        ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
                            ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
                        batch_size);
        }
    }
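    // Example (illustrative): assuming FmhaPipeline::kM0 = 128 and kN1 = hdim_v, a
    // batch-mode problem with nhead = 8, max_seqlen_q = 4096 and hdim_v = 128 yields
    // dim3(8, ceil(4096/128) * ceil(128/128), batch) = dim3(8, 32, batch). Group mode swaps
    // the batch and tile axes so GetTileIndex() below can unpack blockIdx consistently.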
    CK_TILE_DEVICE static constexpr auto
    RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
    {
        if(remap_option < 1)
        {
            return make_tuple(static_cast<int32_t>(gridDim.x - tg_idx - 1), tg_idy);
        }

        int32_t remapped_tg_idx = tg_idx;
        int32_t remapped_tg_idy = tg_idy;

        if(remap_option == 2)
        { // special remapping
            int32_t tmp0 = (remapped_tg_idy & 0x7) * gridDim.x + remapped_tg_idx;
            int32_t tmp1 = tmp0 & 0x7;

            remapped_tg_idx = tmp0 >> 3;
            remapped_tg_idy = (remapped_tg_idy & 0xfffffff8) + tmp1;
        }
        else
        { // normal remapping
            int32_t cus_per_xdim_per_xcc = gridDim.x >> 3;
            int32_t tgs_cu_id            = remapped_tg_idx >> 3;

            if(tgs_cu_id < cus_per_xdim_per_xcc)
            {
                int32_t tgs_xcc_id = remapped_tg_idx & 0x7;
                int32_t new_tg_idx = tgs_xcc_id * cus_per_xdim_per_xcc + tgs_cu_id;

                remapped_tg_idx = new_tg_idx;
            }
        }

        return make_tuple(remapped_tg_idx, remapped_tg_idy);
    }
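    // Example (illustrative): the 0x7 masks / >>3 shifts assume workgroups are dispatched
    // round-robin over 8 dies (e.g. the 8 XCDs of an MI300-class GPU). With gridDim.x = 64,
    // cus_per_xdim_per_xcc = 8; tg_idx = 10 was dispatched as block 10 >> 3 = 1 on XCD
    // 10 & 0x7 = 2, so it is renumbered to 2 * 8 + 1 = 17. Consecutive renumbered indices
    // therefore share a die, improving cache locality between neighboring tiles.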
    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs&)
    {
        using namespace ck_tile;

        // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
        // FmhaPipeline::kN1);

        // assume that num_tile_n1 is always 1
        if constexpr(kIsGroupMode)
        {
            const index_t i_nhead = blockIdx.x;
            const index_t i_batch = blockIdx.y;
            const index_t i_block = blockIdx.z;

            if constexpr(kHasMask)
            {
                return ck_tile::make_tuple(gridDim.z - 1 - i_block, 0, i_nhead, i_batch);
            }
            else
            {
                return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
            }
        }
        else
        {
            const index_t i_nhead = blockIdx.x;
            const index_t i_block = blockIdx.y;
            const index_t i_batch = blockIdx.z;

            if constexpr(kHasMask)
            {
                return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
            }
            else
            {
                return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
            }
        }
    }
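    // Note: with masking enabled the m0 tile index is reversed (gridDim - 1 - i_block);
    // under a causal mask the higher row tiles cover more keys and thus more work, so
    // issuing them first helps reduce the tail effect of scheduling the heaviest
    // workgroups last.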
    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
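    // The main loop and the epilogue execute one after the other in the same workgroup, so
    // a single LDS buffer sized by max() (rather than the sum) can serve both stages.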
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        using namespace ck_tile;

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        // divide problem
        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);

        const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
        const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);

        long_index_t batch_offset_q   = 0;
        long_index_t batch_offset_k   = 0;
        long_index_t batch_offset_v   = 0;
        long_index_t batch_offset_lse = 0;
        long_index_t batch_offset_o   = 0;

        if constexpr(kIsGroupMode)
        {
            // Use seqstart_q_ptr and seqstart_k_ptr for physical starts
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            const long_index_t key_start   = kargs.seqstart_k_ptr[i_batch];

            batch_offset_q = query_start * kargs.stride_q;
            batch_offset_k = key_start * kargs.stride_k;
            batch_offset_v = key_start * kargs.stride_v;

            if constexpr(kStoreLSE)
            {
                // LSE layout is [nhead, total_seqlen], index by unpadded start
                batch_offset_lse = query_start;
            }
            batch_offset_o = query_start * kargs.stride_o;

            // real logical lengths (excluding PAD)
            // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
            if(kargs.seqlen_q_ptr != nullptr)
            {
                kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
            }
            else if(kargs.cu_seqlen_q_ptr != nullptr)
            {
                kargs.seqlen_q =
                    kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
            }
            else
            {
                kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
            }
            // the number of required blocks differs per group; terminate unnecessary
            // blocks early
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }

            if(kargs.seqlen_k_ptr != nullptr)
            {
                kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
            }
            else if(kargs.cu_seqlen_k_ptr != nullptr)
            {
                kargs.seqlen_k =
                    kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
            }
            else
            {
                kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
            }
        }
        else
        {
            batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
            batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
            batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
            if constexpr(kStoreLSE)
            {
                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
            }
            batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;

            // If cumulative seqlen pointers are provided, override per-batch effective lengths
            if(kargs.cu_seqlen_q_ptr != nullptr)
            {
                kargs.seqlen_q =
                    kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
            }
            if(kargs.cu_seqlen_k_ptr != nullptr)
            {
                kargs.seqlen_k =
                    kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
            }
        }
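        // Example (illustrative): with cu_seqlen_q_ptr = {0, 100, 250, 300}, batch 1 gets
        // seqlen_q = 250 - 100 = 150 regardless of the padded allocation described by
        // batch_stride_q, so tail padding is skipped entirely.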
        // for simplicity, the batch offset is applied by adjusting the pointers directly
        const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
                                 batch_offset_q;
        const KDataType* k_ptr =
            reinterpret_cast<const KDataType*>(kargs.k_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
            batch_offset_k;
        const VDataType* v_ptr =
            reinterpret_cast<const VDataType*>(kargs.v_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
            batch_offset_v;
        ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
                           batch_offset_o;
        // Q/K/V DRAM and DRAM window
        const auto q_dram = [&]() {
            const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                q_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_q, 1),
                number<FmhaPipeline::kAlignmentQ>{},
                number<1>{});

            return pad_tensor_view(
                q_dram_naive,
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }();
        const auto k_dram = [&]() {
            const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                k_ptr,
                make_tuple(kargs.seqlen_k, kargs.hdim_q),
                make_tuple(kargs.stride_k, 1),
                number<FmhaPipeline::kAlignmentK>{},
                number<1>{});

            return pad_tensor_view(
                k_dram_naive,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
                sequence<kPadSeqLenK, kPadHeadDimQ>{});
        }();
        const auto v_dram = [&]() {
            const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                v_ptr,
                make_tuple(kargs.seqlen_k, kargs.hdim_v),
                make_tuple(kargs.stride_v, 1),
                number<FmhaPipeline::kAlignmentV>{},
                number<1>{});

            return pad_tensor_view(
                v_dram_naive,
                make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
                sequence<kPadSeqLenK, kPadHeadDimV>{});
        }();

        auto q_dram_window = make_tile_window(
            q_dram,
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
            {i_m0, 0});

        auto k_dram_window = make_tile_window(
            k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});

        auto v_dram_window =
            make_tile_window(v_dram,
                             make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
                             {0, i_n1});

        // lse
        auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
            if constexpr(kStoreLSE)
            {
                LSEDataType* lse_ptr =
                    reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
                    batch_offset_lse;

                const auto lse_dram = [&]() {
                    const auto lse_dram_naive =
                        make_naive_tensor_view<address_space_enum::global>(lse_ptr,
                                                                           make_tuple(kargs.seqlen_q),
                                                                           make_tuple(1),
                                                                           number<1>{},
                                                                           number<1>{});

                    return pad_tensor_view(
                        lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
                }();

                return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
            }
            else
            {
                return make_null_tile_window(lse_dram_window_lengths);
            }
        }();
        FmhaMask mask = [&]() {
            if constexpr(kHasMask)
                return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
                    kargs.window_size_left,
                    kargs.window_size_right,
                    kargs.seqlen_q,
                    kargs.seqlen_k,
                    kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
            else
                return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
        }();
        auto o_acc_tile = [&]() {
            return FmhaPipeline{}(q_dram_window,
                                  k_dram_window,
                                  v_dram_window,
                                  lse_dram_window,
                                  mask,
                                  kargs.scale_s,
                                  smem_ptr);
        }();
        // O DRAM and O DRAM window
        auto o_dram = [&]() {
            const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.stride_o, 1),
                number<FmhaPipeline::kAlignmentO>{},
                number<1>{});

            return pad_tensor_view(
                o_dram_naive,
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                sequence<kPadSeqLenQ, kPadHeadDimV>{});
        }();

        auto o_dram_window =
            make_tile_window(o_dram,
                             make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                             {i_m0, i_n1});
        EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
    }
};
} // namespace ck_tile