/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp Source File
fmha_fwd_v3_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
9 
10 #include <type_traits>
11 #include <utility>
12 
13 namespace ck_tile {
14 
// Forward FMHA (flash-attention) v3 kernel wrapper, parameterized by a
// compute pipeline and an epilogue pipeline.
// NOTE(review): this listing is a Doxygen extraction; several original lines
// were dropped by the scraper (e.g. the kernel struct declaration on original
// line 16 and some member declarations). Code lines below are kept
// byte-identical; only comments are added.
15 template <typename FmhaPipeline_, typename EpiloguePipeline_>
17 {
// Thread-block geometry forwarded from the pipeline policy.
20  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
21  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
22  static_assert(kBlockPerCu > 0);
23 
30 
// Compile-time feature flags forwarded from the pipeline: group vs batch
// mode, padding on each problem dimension, and whether the log-sum-exp
// (LSE) tensor is written out.
31  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
32  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
33  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
34  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
35  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
36  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
37 
39  static constexpr bool kHasMask = FmhaMask::IsMasking;
40 
// Empty tag base; the index parameter I keeps each instantiation a distinct
// type so the two conditional bases below can never collide.
41  template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
42  // arg
44  {
45  };
46 
47  // kargs use aggregate initializer, so no constructor will provided
48  // use inheritance to minimize karg size
49  // user need to use MakeKargs() function to create kargs.
// Common kernel arguments shared by batch and group mode. NOTE(review): the
// struct header line was lost in extraction; the file's cross-reference
// section names this FmhaFwdCommonKargs, and several member declaration
// lines (seqlen/stride fields) are also missing here.
51  {
52  const void* q_ptr;
53  const void* k_ptr;
54  const void* v_ptr;
55  void* o_ptr;
56 
61 
63  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
64  // if this param is larger than 1, indicate MQA/GQA case
66  float scale_s;
67 
72 
77  };
78 
// Mask-related kargs, mixed in (via inheritance) only when kHasMask holds.
// NOTE(review): struct header and member lines lost in extraction; the
// cross-reference section names it FmhaFwdMaskKargs with window_size_left,
// window_size_right and mask_type members.
80  {
81  // ck_tile::index_t window_size_left, window_size_right;
84  };
85 
// LSE output kargs, mixed in only when kStoreLSE holds.
87  {
88  void* lse_ptr = nullptr;
91  };
92 
// Batch-mode kargs: common kargs + conditional mask/LSE bases + per-batch
// strides. NOTE(review): the struct header line was dropped by the scraper.
95  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
96  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
97  {
102  };
103 
// Group-mode kargs: same conditional bases, but sequence extents come from
// seqstart/seqlen pointers instead of per-batch strides (see operator()).
106  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
107  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
108  {
112  };
113 
// Select the karg layout at compile time based on the pipeline's mode.
114  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
115 
// Host-side factory for BATCH-mode kernel arguments. SFINAE-gated so this
// overload only participates when the kernel is compiled for batch mode
// (!kIsGroupMode). Returns a fully-populated Kargs aggregate.
116  template <bool Cond = !kIsGroupMode>
117  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
118  MakeKargs(const void* q_ptr,
119  const void* k_ptr,
120  const void* v_ptr,
121  void* lse_ptr,
122  void* o_ptr,
123  ck_tile::index_t seqlen_q,
124  ck_tile::index_t seqlen_k,
125  ck_tile::index_t hdim_q,
126  ck_tile::index_t hdim_v,
127  ck_tile::index_t num_head_q,
128  ck_tile::index_t nhead_ratio_qk,
129  float scale_s,
130  ck_tile::index_t stride_q,
131  ck_tile::index_t stride_k,
132  ck_tile::index_t stride_v,
133  ck_tile::index_t stride_o,
134  ck_tile::index_t nhead_stride_q,
135  ck_tile::index_t nhead_stride_k,
136  ck_tile::index_t nhead_stride_v,
137  ck_tile::index_t nhead_stride_lse,
138  ck_tile::index_t nhead_stride_o,
139  ck_tile::index_t batch_stride_q,
140  ck_tile::index_t batch_stride_k,
141  ck_tile::index_t batch_stride_v,
142  ck_tile::index_t batch_stride_lse,
143  ck_tile::index_t batch_stride_o,
144  ck_tile::index_t window_size_left,
145  ck_tile::index_t window_size_right,
146  ck_tile::index_t mask_type)
147  {
// Aggregate initialization: the inner brace initializes the common-karg
// base, so the field order here must match that struct's declaration order.
148  Kargs kargs{{q_ptr,
149  k_ptr,
150  v_ptr,
151  o_ptr,
152  seqlen_q,
153  seqlen_k,
154  hdim_q,
155  hdim_v,
156  num_head_q,
157  nhead_ratio_qk,
// Pre-scale the softmax scale by log2(e) — presumably so the device code
// can use exp2 rather than exp (TODO confirm against the pipeline).
158  static_cast<float>(scale_s * ck_tile::log2e_v<>),
159  stride_q,
160  stride_k,
161  stride_v,
162  stride_o,
163  nhead_stride_q,
164  nhead_stride_k,
165  nhead_stride_v,
166  nhead_stride_o}, // args for common karg
167  {}, // placeholder for mask
168  {}, // placeholder for lse
169  batch_stride_q,
170  batch_stride_k,
171  batch_stride_v,
172  batch_stride_o};
173 
// The mask/LSE members only exist in the corresponding conditional base,
// so they can only be assigned when the feature is compiled in.
174  if constexpr(kHasMask)
175  {
176  kargs.window_size_left = window_size_left;
177  kargs.window_size_right = window_size_right;
178  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
179  }
180  if constexpr(kStoreLSE)
181  {
182  kargs.lse_ptr = lse_ptr;
183  kargs.nhead_stride_lse = nhead_stride_lse;
184  kargs.batch_stride_lse = batch_stride_lse;
185  }
186 
187  return kargs;
188  }
189 
// Host-side factory for GROUP-mode (variable sequence length) kernel
// arguments. SFINAE-gated so this overload only participates when the
// kernel is compiled for group mode. Instead of fixed seqlen/batch strides,
// it takes seqstart_q/seqstart_k prefix arrays (and an optional per-batch
// seqlen_k array); actual lengths are resolved per-batch on device.
190  template <bool Cond = kIsGroupMode>
191  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
192  MakeKargs(const void* q_ptr,
193  const void* k_ptr,
194  const void* v_ptr,
195  void* lse_ptr,
196  void* o_ptr,
197  const void* seqstart_q_ptr,
198  const void* seqstart_k_ptr,
199  const void* seqlen_k_ptr,
200  ck_tile::index_t hdim_q,
201  ck_tile::index_t hdim_v,
202  ck_tile::index_t num_head_q,
203  ck_tile::index_t nhead_ratio_qk,
204  float scale_s,
205  ck_tile::index_t stride_q,
206  ck_tile::index_t stride_k,
207  ck_tile::index_t stride_v,
208  ck_tile::index_t stride_o,
209  ck_tile::index_t nhead_stride_q,
210  ck_tile::index_t nhead_stride_k,
211  ck_tile::index_t nhead_stride_v,
212  ck_tile::index_t nhead_stride_lse,
213  ck_tile::index_t nhead_stride_o,
214  ck_tile::index_t window_size_left,
215  ck_tile::index_t window_size_right,
216  ck_tile::index_t mask_type)
217  {
// Aggregate initialization; inner brace fills the common-karg base in
// declaration order. seqlen_q/seqlen_k are set to the sentinel -1 here and
// are overwritten in operator() from the seqstart/seqlen pointers.
218  Kargs kargs{{q_ptr,
219  k_ptr,
220  v_ptr,
221  o_ptr,
222  -1, // seqlen will be updated by another pointer
223  -1, //
224  hdim_q,
225  hdim_v,
226  num_head_q,
227  nhead_ratio_qk,
// Pre-scale by log2(e); mirrors the batch-mode factory.
228  static_cast<float>(scale_s * ck_tile::log2e_v<>),
229  stride_q,
230  stride_k,
231  stride_v,
232  stride_o,
233  nhead_stride_q,
234  nhead_stride_k,
235  nhead_stride_v,
236  nhead_stride_o}, // args for common karg
237  {}, // placeholder for mask
238  {}, // placeholder for lse
239  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
240  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
241  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
242 
// Conditional-base members, assignable only when the feature is enabled.
243  if constexpr(kHasMask)
244  {
245  kargs.window_size_left = window_size_left;
246  kargs.window_size_right = window_size_right;
247  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
248  }
// Group mode has no batch_stride_lse: the LSE batch offset is derived from
// seqstart_q on device.
249  if constexpr(kStoreLSE)
250  {
251  kargs.lse_ptr = lse_ptr;
252  kargs.nhead_stride_lse = nhead_stride_lse;
253  }
254 
255  return kargs;
256  }
257 
258  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
259  ck_tile::index_t nhead_,
260  ck_tile::index_t seqlen_q_,
261  ck_tile::index_t hdim_v_)
262  {
263  // TODO: this may need tuning
264  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
265  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
266  nhead_,
267  batch_size_);
268  }
269 
270  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
271  {
272  using namespace ck_tile;
273 
274  // const index_t num_tile_m0 = seqlen_q / kM0;
275  const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
276 
277  const index_t i_block = blockIdx.x;
278  const index_t i_nhead = blockIdx.y;
279  const index_t i_batch = blockIdx.z;
280 
281  const auto f = [](index_t dividend, index_t divisor) {
282  index_t quotient = dividend / divisor;
283  index_t modulus = dividend - quotient * divisor;
284  return ck_tile::make_tuple(quotient, modulus);
285  };
286 
287  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
288 
289  if constexpr(kHasMask)
290  {
291  // assume that num_tile_n1 is always 1
292  return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
293  }
294  else
295  {
296  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
297  }
298  }
299 
// Launch geometry: a one-dimensional thread block of kBlockSize threads.
300  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
301 
// LDS requirement: the compute pipeline and the epilogue reuse the same
// shared-memory buffer, so the kernel needs the max of the two sizes.
// NOTE(review): the function signature line (GetSmemSize, original line 302)
// was dropped by the extraction; the body below is complete.
303  {
304  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
305  }
306 
// Device entry point. Each thread block: (1) resolves its tile coordinates,
// (2) computes per-batch/per-head base offsets (group mode reads them from
// the seqstart arrays, batch mode from batch strides), (3) builds padded
// DRAM tensor views/windows for Q, K, V (and LSE when enabled), (4) runs
// the attention pipeline, and (5) writes the output tile via the epilogue.
// NOTE(review): several call-argument lines inside the tensor-view /
// tile-window constructions were dropped by the Doxygen extraction (visible
// as gaps in the embedded numbering); code lines below are byte-identical.
307  CK_TILE_DEVICE void operator()(Kargs kargs) const
308  {
309  using namespace ck_tile;
310 
311  // allocate LDS
312  __shared__ char smem_ptr[GetSmemSize()];
313 
314  // divide problem
315  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
316 
// readfirstlane: tile origins are uniform across the wavefront, so keep
// them in scalar registers.
317  const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
318  const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
319 
// 64-bit offsets: batch * stride can overflow 32-bit index_t.
320  long_index_t batch_offset_q = 0;
321  long_index_t batch_offset_k = 0;
322  long_index_t batch_offset_v = 0;
323  long_index_t batch_offset_lse = 0;
324  long_index_t batch_offset_o = 0;
325 
326  if constexpr(kIsGroupMode)
327  {
328  // get starting offset for each batch
329  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
330  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
331 
332  batch_offset_q = query_start * kargs.stride_q;
333  batch_offset_k = key_start * kargs.stride_k;
334  batch_offset_v = key_start * kargs.stride_v;
335 
// LSE is laid out per-token along the sequence, so its batch offset is
// just the token start index (no row stride).
336  if constexpr(kStoreLSE)
337  {
338  batch_offset_lse = query_start;
339  }
340  batch_offset_o = query_start * kargs.stride_o;
341 
342  // get real # queries & # keys under group mode
// seqlen_q = seqstart_q[i_batch + 1] - seqstart_q[i_batch]
343  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
344  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
345 
346  // # of required blocks is different in each groups, terminate unnecessary blocks
347  // earlier
348  if(kargs.seqlen_q <= i_m0)
349  {
350  return;
351  }
352 
// seqlen_k: prefer an explicit per-batch array; otherwise derive it from
// consecutive seqstart_k entries, like seqlen_q above.
353  if(kargs.seqlen_k_ptr != nullptr)
354  {
355  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
356  }
357  else
358  {
359  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
360  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
361  }
362  }
363  else
364  {
// Batch mode: plain batch-index * batch-stride offsets, widened to 64-bit
// before the multiply.
365  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
366  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
367  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
368  if constexpr(kStoreLSE)
369  {
370  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
371  }
372  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
373  }
374 
375  // for simplicity, batch stride we just modify the pointer
// K/V use i_nhead / nhead_ratio_qk so that multiple query heads map onto
// one shared KV head (MQA/GQA).
376  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
377  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
378  batch_offset_q;
379  const KDataType* k_ptr =
380  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
381  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
382  batch_offset_k;
383  const VDataType* v_ptr =
384  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
385  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
386  batch_offset_v;
387  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
388  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
389  batch_offset_o;
390 
391  // Q/K/V DRAM and DRAM window
// Each view is (rows, columns) with a row stride and unit column stride,
// then padded per the kPad* flags. NOTE(review): the pad-flag and
// tile-length argument lines of these calls were lost in extraction.
392  const auto q_dram = [&]() {
393  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
394  q_ptr,
395  make_tuple(kargs.seqlen_q, kargs.hdim_q),
396  make_tuple(kargs.stride_q, 1),
398  number<1>{});
399 
400  return pad_tensor_view(
401  q_dram_naive,
404  }();
405  const auto k_dram = [&]() {
406  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
407  k_ptr,
408  make_tuple(kargs.seqlen_k, kargs.hdim_q),
409  make_tuple(kargs.stride_k, 1),
411  number<1>{});
412 
413  return pad_tensor_view(
414  k_dram_naive,
417  }();
418  const auto v_dram = [&]() {
419  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
420  v_ptr,
421  make_tuple(kargs.seqlen_k, kargs.hdim_v),
422  make_tuple(kargs.stride_v, 1),
424  number<1>{});
425 
426  return pad_tensor_view(
427  v_dram_naive,
430  }();
431 
// Tile windows anchored at this block's tile origin: Q at row i_m0,
// V at column i_n1; K starts at the origin (window-length lines lost in
// extraction).
432  auto q_dram_window = make_tile_window(
433  q_dram,
435  {i_m0, 0});
436 
437  auto k_dram_window = make_tile_window(
439 
440  auto v_dram_window =
441  make_tile_window(v_dram,
443  {0, i_n1});
444 
445  // lse
// When LSE output is disabled this collapses to a null window, so the
// pipeline's interface is uniform either way. i_nhead_ init-capture avoids
// capturing the structured binding directly.
446  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
447  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
448  if constexpr(kStoreLSE)
449  {
450  LSEDataType* lse_ptr =
451  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
452  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
453 
454  const auto lse_dram = [&]() {
455  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
456  lse_ptr,
457  make_tuple(kargs.seqlen_q),
458  make_tuple(1),
459  number<1>{},
460  number<1>{});
461 
462  return pad_tensor_view(
463  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
464  }();
465 
466  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
467  }
468  else
469  {
470  return make_null_tile_window(lse_dram_window_lengths);
471  }
472  }();
473 
// Build the attention mask: from the left/right window kargs when masking
// is compiled in, otherwise a trivial full mask over (seqlen_q, seqlen_k).
// NOTE(review): the closing-paren line of the masked call was lost in
// extraction.
474  FmhaMask mask = [&]() {
475  if constexpr(kHasMask)
476  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
477  kargs.window_size_left,
478  kargs.window_size_right,
479  kargs.seqlen_q,
480  kargs.seqlen_k,
482  else
483  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
484  }();
485 
// Run the attention pipeline; result is the accumulated output tile.
486  auto o_acc_tile = [&]() {
487  return FmhaPipeline{}(q_dram_window,
488  k_dram_window,
489  v_dram_window,
490  lse_dram_window,
491  mask,
492  kargs.scale_s,
493  smem_ptr);
494  }();
495 
496  // O DRAM and O DRAM window
497  auto o_dram = [&]() {
498  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
499  o_ptr,
500  make_tuple(kargs.seqlen_q, kargs.hdim_v),
501  make_tuple(kargs.stride_o, 1),
503  number<1>{});
504 
505  return pad_tensor_view(
506  o_dram_naive,
509  }();
510 
511  auto o_dram_window =
512  make_tile_window(o_dram,
514  {i_m0, i_n1});
515 
// Epilogue stores the accumulated tile to global memory.
516  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
517  }
518 };
519 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: fmha_fwd_v3_kernel.hpp:97
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_v3_kernel.hpp:101
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_v3_kernel.hpp:98
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_v3_kernel.hpp:99
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_v3_kernel.hpp:100
Definition: fmha_fwd_v3_kernel.hpp:51
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_v3_kernel.hpp:75
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_v3_kernel.hpp:73
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_v3_kernel.hpp:65
ck_tile::index_t stride_o
Definition: fmha_fwd_v3_kernel.hpp:71
ck_tile::index_t seqlen_k
Definition: fmha_fwd_v3_kernel.hpp:58
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_v3_kernel.hpp:76
ck_tile::index_t hdim_q
Definition: fmha_fwd_v3_kernel.hpp:59
ck_tile::index_t seqlen_q
Definition: fmha_fwd_v3_kernel.hpp:57
ck_tile::index_t stride_v
Definition: fmha_fwd_v3_kernel.hpp:70
const void * q_ptr
Definition: fmha_fwd_v3_kernel.hpp:52
float scale_s
Definition: fmha_fwd_v3_kernel.hpp:66
const void * v_ptr
Definition: fmha_fwd_v3_kernel.hpp:54
void * o_ptr
Definition: fmha_fwd_v3_kernel.hpp:55
ck_tile::index_t stride_q
Definition: fmha_fwd_v3_kernel.hpp:68
ck_tile::index_t num_head_q
Definition: fmha_fwd_v3_kernel.hpp:62
const void * k_ptr
Definition: fmha_fwd_v3_kernel.hpp:53
ck_tile::index_t hdim_v
Definition: fmha_fwd_v3_kernel.hpp:60
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_v3_kernel.hpp:74
ck_tile::index_t stride_k
Definition: fmha_fwd_v3_kernel.hpp:69
Definition: fmha_fwd_v3_kernel.hpp:87
void * lse_ptr
Definition: fmha_fwd_v3_kernel.hpp:88
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_v3_kernel.hpp:89
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_v3_kernel.hpp:90
Definition: fmha_fwd_v3_kernel.hpp:44
Definition: fmha_fwd_v3_kernel.hpp:108
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:111
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_v3_kernel.hpp:110
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_v3_kernel.hpp:109
Definition: fmha_fwd_v3_kernel.hpp:80
ck_tile::index_t window_size_left
Definition: fmha_fwd_v3_kernel.hpp:82
ck_tile::index_t window_size_right
Definition: fmha_fwd_v3_kernel.hpp:82
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_v3_kernel.hpp:83
Definition: fmha_fwd_v3_kernel.hpp:17
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_v3_kernel.hpp:32
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_v3_kernel.hpp:24
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_v3_kernel.hpp:307
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_v3_kernel.hpp:33
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_v3_kernel.hpp:302
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_)
Definition: fmha_fwd_v3_kernel.hpp:258
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_v3_kernel.hpp:20
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_v3_kernel.hpp:28
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_v3_kernel.hpp:25
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_v3_kernel.hpp:35
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_fwd_v3_kernel.hpp:300
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_v3_kernel.hpp:18
static constexpr bool kHasMask
Definition: fmha_fwd_v3_kernel.hpp:39
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_v3_kernel.hpp:114
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
Definition: fmha_fwd_v3_kernel.hpp:192
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_v3_kernel.hpp:19
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_v3_kernel.hpp:34
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_v3_kernel.hpp:26
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_v3_kernel.hpp:21
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
Definition: fmha_fwd_v3_kernel.hpp:118
static constexpr bool kStoreLSE
Definition: fmha_fwd_v3_kernel.hpp:36
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_v3_kernel.hpp:29
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_v3_kernel.hpp:27
static constexpr bool kIsGroupMode
Definition: fmha_fwd_v3_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_v3_kernel.hpp:38
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_v3_kernel.hpp:270
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49