Composable Kernel: include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVCombineKernel
{
    using FmhaPipeline     = remove_cvref_t<FmhaPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    static constexpr index_t kNumWarps   = FmhaPipeline::kNumWarps;
    static constexpr index_t kBlockSize  = FmhaPipeline::kBlockSize;
    static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;

    static_assert(kBlockPerCu > 0);
    static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

    using LSEDataType  = remove_cvref_t<typename FmhaPipeline::LSEDataType>;
    using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
    using ODataType    = remove_cvref_t<typename FmhaPipeline::ODataType>;

    static constexpr bool kIsGroupMode      = FmhaPipeline::kIsGroupMode;
    static constexpr bool kPadSeqLenQ       = FmhaPipeline::kPadSeqLenQ;
    static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
    static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
    static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on

    __host__ static std::string GetName()
    {
        // sync with generate.py
        // clang-format off

        #define _SS_ std::string
        #define _TS_ std::to_string
        auto pn = [&] () {
            std::string n;
            if (kPadSeqLenQ) n += "s";
            if (kPadHeadDimV) n += "dv";
            return n.empty() ? n : std::string("p") + n; }();
        return
            _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s<ODataType>::name) +
            "_" + (kIsGroupMode ? "group" : "batch") + "_"
            "b" + _TS_(FmhaPipeline::kN1) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
            _SS_(FmhaPipeline::name) +
            (pn.empty() ? "_npad" : "_" + pn) +
            (kStoreLSE ? "_lse" : "_nlse" ) +
            (kDoFp8StaticQuant ? "_squant" : "_nsquant" );
        #undef _SS_
        #undef _TS_
        // clang-format on
    }
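
    // For example (a hypothetical configuration, not taken from this file): kHeadDimV = 128,
    // fp16 output, batch mode, kN1 = 32, no padding, no stored LSE and no fp8 static
    // quantization would yield a name like
    //   "fmha_fwd_splitkv_combine_d128_fp16_batch_b32_<pipeline-name>_npad_nlse_nsquant"
    // where <pipeline-name> stands for FmhaPipeline::name and the "o<kBlockPerCu>_" occupancy
    // tag is inserted only when kBlockPerCuInput != -1.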

    template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a
                                  // template arg
    struct EmptyKargs
    {
    };

    // kargs use aggregate initialization, so no constructor is provided
    // use inheritance to minimize karg size
    // users need to use the MakeKargs() function to create kargs.
    struct CommonKargs
    {
        const void* lse_acc_ptr;
        const void* o_acc_ptr;
        void* o_ptr;

        ck_tile::index_t batch;
        ck_tile::index_t seqlen_q;
        ck_tile::index_t hdim_v;
        ck_tile::index_t num_splits;

        ck_tile::index_t row_stride_o_acc;
        ck_tile::index_t row_stride_o;

        ck_tile::index_t nhead_stride_lse_acc;
        ck_tile::index_t nhead_stride_o_acc;
        ck_tile::index_t nhead_stride_o;

        ck_tile::index_t split_stride_lse_acc;
        ck_tile::index_t split_stride_o_acc;
    };

    struct CommonLSEKargs
    {
        void* lse_ptr = nullptr;
        ck_tile::index_t nhead_stride_lse;
        ck_tile::index_t batch_stride_lse;
    };

    struct Fp8StaticQuantKargs
    {
        float scale_o;
    };

    struct BatchModeKargs
        : CommonKargs,
          std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
    {
        ck_tile::index_t batch_stride_lse_acc;
        ck_tile::index_t batch_stride_o_acc;
        ck_tile::index_t batch_stride_o;
    };

    struct GroupModeKargs
        : CommonKargs,
          std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<3>>
    {
        const int32_t* seqstart_q_ptr;
    };

    using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;

    template <bool Cond = !kIsGroupMode>
    __host__ static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* lse_acc_ptr,
              const void* o_acc_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t batch,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_splits,
              float scale_o,
              ck_tile::index_t row_stride_o_acc,
              ck_tile::index_t row_stride_o,
              ck_tile::index_t nhead_stride_lse_acc,
              ck_tile::index_t nhead_stride_o_acc,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_lse_acc,
              ck_tile::index_t batch_stride_o_acc,
              ck_tile::index_t batch_stride_lse,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t split_stride_lse_acc,
              ck_tile::index_t split_stride_o_acc)
    {
        Kargs kargs{{lse_acc_ptr,
                     o_acc_ptr,
                     o_ptr,
                     batch,
                     seqlen_q,
                     hdim_v,
                     num_splits,
                     row_stride_o_acc,
                     row_stride_o,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     nhead_stride_o,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {},                   // placeholder for lse
                    {},                   // placeholder for fp8_static_quant args
                    batch_stride_lse_acc,
                    batch_stride_o_acc,
                    batch_stride_o};

        if constexpr(kStoreLSE)
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
            kargs.batch_stride_lse = batch_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_o = scale_o;
        }

        return kargs;
    }

    template <bool Cond = kIsGroupMode>
    __host__ static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* lse_acc_ptr,
              const void* o_acc_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t batch,
              const void* seqstart_q_ptr,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_splits,
              float scale_o,
              ck_tile::index_t row_stride_o_acc,
              ck_tile::index_t row_stride_o,
              ck_tile::index_t nhead_stride_lse_acc,
              ck_tile::index_t nhead_stride_o_acc,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t split_stride_lse_acc,
              ck_tile::index_t split_stride_o_acc)
    {
        Kargs kargs{{lse_acc_ptr,
                     o_acc_ptr,
                     o_ptr,
                     batch,
                     -1, // seqlen will be updated by another pointer
                     hdim_v,
                     num_splits,
                     row_stride_o_acc,
                     row_stride_o,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     nhead_stride_o,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {},                   // placeholder for lse
                    {},                   // placeholder for fp8_static_quant args
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr)};

        if constexpr(kStoreLSE)
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_o = scale_o;
        }

        return kargs;
    }
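
    // Illustrative host-side usage (a minimal sketch; the args.* names below are assumptions,
    // not part of this header). In batch mode the caller forwards the per-batch strides:
    //
    //   auto kargs = FmhaFwdSplitKVCombineKernel<FmhaPipeline, EpiloguePipeline>::MakeKargs(
    //       args.lse_acc_ptr, args.o_acc_ptr, args.lse_ptr, args.o_ptr,
    //       args.batch, args.seqlen_q, args.hdim_v, args.num_splits, args.scale_o,
    //       args.row_stride_o_acc, args.row_stride_o,
    //       args.nhead_stride_lse_acc, args.nhead_stride_o_acc,
    //       args.nhead_stride_lse, args.nhead_stride_o,
    //       args.batch_stride_lse_acc, args.batch_stride_o_acc,
    //       args.batch_stride_lse, args.batch_stride_o,
    //       args.split_stride_lse_acc, args.split_stride_o_acc);
    //
    // In group mode, seqlen_q is replaced by seqstart_q_ptr and the batch_stride_* arguments
    // are omitted (see the overload above).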

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t max_seqlen_q,
                                                ck_tile::index_t hdim_v)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
                        ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
                    nhead,
                    batch_size);
    }
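
    // For example (hypothetical sizes, assuming kM0 = 64 and kN1 = 32): max_seqlen_q = 1000 and
    // hdim_v = 128 give ceil(1000 / 64) * ceil(128 / 32) = 16 * 4 = 64 blocks along grid.x,
    // with grid.y = nhead and grid.z = batch_size.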

    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
    {
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);

        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck_tile::make_tuple(quotient, modulus);
        };

        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
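
    // Continuing the example above (num_tile_n1 = 4): blockIdx.x = 10 decodes to
    // i_tile_m = 10 / 4 = 2 and i_tile_n = 10 % 4 = 2, i.e. the third kM0 row tile and the
    // third kN1 column tile of the output.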

    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        // divide problem
        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
        const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

        long_index_t batch_offset_lse_acc = 0;
        long_index_t batch_offset_o_acc   = 0;
        long_index_t batch_offset_lse     = 0;
        long_index_t batch_offset_o       = 0;

        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];

            batch_offset_lse_acc = query_start;
            batch_offset_o_acc   = query_start * kargs.row_stride_o_acc;

            if constexpr(kStoreLSE)
            {
                batch_offset_lse = query_start;
            }

            batch_offset_o = query_start * kargs.row_stride_o;

            // get the real # of queries under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];

            // the # of required blocks differs between groups, so terminate unnecessary blocks
            // early
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }
        }
        else
        {
            batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
            batch_offset_o_acc   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;

            if constexpr(kStoreLSE)
            {
                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
            }

            batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
        }

        // for simplicity, we handle the batch stride by simply advancing the pointers
        const LSEDataType* lse_acc_ptr =
            reinterpret_cast<const LSEDataType*>(kargs.lse_acc_ptr) +
            static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lse_acc + batch_offset_lse_acc;
        const OaccDataType* o_acc_ptr =
            reinterpret_cast<const OaccDataType*>(kargs.o_acc_ptr) +
            static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc;
        ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
                           batch_offset_o;

        // LSEacc/Oacc DRAM and DRAM windows
        const auto lse_acc_dram = [&]() {
            const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                lse_acc_ptr,
                make_tuple(kargs.num_splits, kargs.seqlen_q),
                make_tuple(kargs.split_stride_lse_acc, 1),
                /* ... */
                number<1>{});

            return pad_tensor_view(
                lse_acc_dram_naive,
                /* ... */);
        }();

        auto o_acc_dram = [&]() {
            const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_acc_ptr,
                make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
                /* ... */
                number<1>{});

            // read 4 * (kM0, kN1) o_acc tiles simultaneously by 4 warps
            const auto o_acc_dram_view = pad_tensor_view(
                o_acc_dram_naive,
                make_tuple(
                    /* ... */));

            const index_t padded_num_splits =
                o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<0>{}];
            const index_t padded_seqlen_q =
                o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
            const index_t padded_hdim_v =
                o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];

            const index_t num_m_tiles = integer_divide_floor(padded_seqlen_q, FmhaPipeline::kM0);

            // transform tensor view by following steps, given shape: (padded_num_splits,
            // padded_seqlen_q, padded_hdim_v)
            // 1. unmerge to (padded_num_splits, num_m_tiles, kM0, padded_hdim_v)
            // 2. transpose to (num_m_tiles, padded_num_splits, kM0, padded_hdim_v)
            // 3. merge to (num_m_tiles * padded_num_splits * kM0, padded_hdim_v)
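            //
            // For example (hypothetical sizes): with padded_num_splits = 8, kM0 = 64 and
            // padded_seqlen_q = 128 (so num_m_tiles = 2), the (split, seq) coordinate
            // (s, 64 * t + r) maps to row t * 8 * 64 + s * 64 + r of the merged view, so all
            // splits of one kM0 row tile end up contiguous along the merged row dimension.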
            auto transposed = transform_tensor_view(
                o_acc_dram_view,
                make_tuple(make_pass_through_transform(padded_num_splits),
                           make_unmerge_transform(make_tuple(num_m_tiles, FmhaPipeline::kM0)),
                           make_pass_through_transform(padded_hdim_v)),
                /* ... */);

            return transform_tensor_view(
                transposed,
                make_tuple(make_merge_transform(
                               make_tuple(num_m_tiles, padded_num_splits, FmhaPipeline::kM0)),
                           make_pass_through_transform(padded_hdim_v)),
                /* ... */);
        }();

        auto lse_acc_dram_window = make_tile_window(
            lse_acc_dram,
            /* ... */
            {0, i_m0});

        const index_t padded_num_splits =
            integer_divide_ceil(kargs.num_splits, kNumWarps) * kNumWarps;

        auto o_acc_dram_window = make_tile_window(
            o_acc_dram,
            /* ... */
            {i_tile_m * padded_num_splits * FmhaPipeline::kM0, i_n1});

        // LSE DRAM window
        auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
            if constexpr(kStoreLSE)
            {
                LSEDataType* lse_ptr =
                    reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;

                const auto lse_dram = [&]() {
                    const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                        lse_ptr,
                        make_tuple(kargs.seqlen_q),
                        make_tuple(1),
                        /* ... */
                        number<1>{});

                    return pad_tensor_view(
                        lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
                }();

                return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
            }
            else
            {
                return make_null_tile_window(lse_dram_window_lengths);
            }
        }();

        auto o_acc_tile = [&]() {
            if constexpr(kDoFp8StaticQuant)
            {
                return FmhaPipeline{}(
                    lse_acc_dram_window,
                    o_acc_dram_window,
                    lse_dram_window,
                    identity{},                                          // lse_element_func
                    composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
                    kargs.num_splits,
                    smem_ptr);
            }
            else
            {
                return FmhaPipeline{}(lse_acc_dram_window,
                                      o_acc_dram_window,
                                      lse_dram_window,
                                      kargs.num_splits,
                                      smem_ptr);
            }
        }();

        // O DRAM and DRAM window
        auto o_dram = [&]() {
            const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.row_stride_o, 1),
                /* ... */
                number<1>{});

            return pad_tensor_view(
                o_dram_naive,
                /* ... */);
        }();

        auto o_dram_window =
            make_tile_window(o_dram,
                             /* ... */
                             {i_m0, i_n1});

        EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
    }
};

} // namespace ck_tile
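
// Illustrative launch sequence (a minimal sketch, not part of this header; the Kernel alias, the
// stream_cfg object and the kargs variable are assumptions). After building kargs with
// Kernel::MakeKargs(...), a host wrapper would typically do something along these lines, using
// the ck_tile host-side launch helpers:
//
//   using Kernel = ck_tile::FmhaFwdSplitKVCombineKernel<FmhaPipeline, EpiloguePipeline>;
//   const dim3 grids  = Kernel::GridSize(batch, nhead, max_seqlen_q, hdim_v);
//   const dim3 blocks = Kernel::BlockSize();
//   ck_tile::launch_kernel(stream_cfg,
//                          ck_tile::make_kernel<Kernel::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));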