Composable Kernel: include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVCombineKernel
{
    using FmhaPipeline     = remove_cvref_t<FmhaPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    static constexpr index_t kNumWarps        = FmhaPipeline::kNumWarps;
    static constexpr index_t kBlockSize       = FmhaPipeline::kBlockSize;
    static constexpr index_t kBlockPerCu      = FmhaPipeline::kBlockPerCu;
    static_assert(kBlockPerCu > 0);
    static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

    using LSEDataType  = remove_cvref_t<typename FmhaPipeline::LSEDataType>;
    using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
    using ODataType    = remove_cvref_t<typename FmhaPipeline::ODataType>;

    static constexpr bool kIsGroupMode      = FmhaPipeline::kIsGroupMode;
    static constexpr bool kPadSeqLenQ       = FmhaPipeline::kPadSeqLenQ;
    static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
    static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
    static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on

    __host__ static std::string GetName()
    {
        // sync with generate.py
        // clang-format off

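        // The emitted name follows this pattern (illustrative, derived from the expression below):
        //   fmha_fwd_splitkv_combine_d<kHeadDimV>_<odtype>_<group|batch>_b<kN1>[_o<kBlockPerCu>]_<pipeline>[_p{s,dv}][_lse][_squant]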
        #define _SS_ std::string
        #define _TS_ std::to_string
        auto pn = [&] () {
            std::string n;
            if (kPadSeqLenQ) n += "s";
            if (kPadHeadDimV) n += "dv";
            return n.empty() ? n : std::string("p") + n; }();
        return
            _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s<ODataType>::name) +
            "_" + (kIsGroupMode ? "group" : "batch") + "_"
            "b" + _TS_(FmhaPipeline::kN1) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
            _SS_(FmhaPipeline::name) +
            (pn.empty() ? "" : "_" + pn) +
            (kStoreLSE ? "_lse" : "" ) +
            (kDoFp8StaticQuant ? "_squant" : "" );
        #undef _SS_
        #undef _TS_
        // clang-format on
    }

    // to avoid the duplicated base class problem, introduce a template arg
    template <ck_tile::index_t I>
    struct EmptyKargs
    {
    };

    // kargs use aggregate initialization, so no constructor will be provided;
    // use inheritance to minimize karg size.
    // users need to use the MakeKargs() function to create kargs.
    struct CommonKargs
    {
        const void* lse_acc_ptr;
        const void* o_acc_ptr;
        void* o_ptr;

        ck_tile::index_t batch;
        ck_tile::index_t seqlen_q;
        ck_tile::index_t hdim_v;
        ck_tile::index_t num_splits;

        ck_tile::index_t row_stride_o_acc;
        ck_tile::index_t row_stride_o;

        ck_tile::index_t nhead_stride_lse_acc;
        ck_tile::index_t nhead_stride_o_acc;
        ck_tile::index_t nhead_stride_o;

        ck_tile::index_t split_stride_lse_acc;
        ck_tile::index_t split_stride_o_acc;
    };

    struct CommonLSEKargs
    {
        void* lse_ptr                     = nullptr;
        ck_tile::index_t nhead_stride_lse = 0;
        ck_tile::index_t batch_stride_lse = 0;
    };

    struct Fp8StaticQuantKargs
    {
        float scale_o;
    };

    struct BatchModeKargs
        : CommonKargs,
          std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
    {
        ck_tile::index_t batch_stride_lse_acc;
        ck_tile::index_t batch_stride_o_acc;
        ck_tile::index_t batch_stride_o;
    };
117 
119  : CommonKargs,
120  std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
121  std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<3>>
122  {
123  const int32_t* seqstart_q_ptr;
124  };
125 
126  using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
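    // e.g. (illustrative) in batch mode with kStoreLSE = true and kDoFp8StaticQuant = false,
    // Kargs is BatchModeKargs, which inherits CommonKargs + CommonLSEKargs + EmptyKargs<1>.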

    template <bool Cond = !kIsGroupMode>
    __host__ static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* lse_acc_ptr,
              const void* o_acc_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t batch,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_splits,
              float scale_o,
              ck_tile::index_t row_stride_o_acc,
              ck_tile::index_t row_stride_o,
              ck_tile::index_t nhead_stride_lse_acc,
              ck_tile::index_t nhead_stride_o_acc,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_lse_acc,
              ck_tile::index_t batch_stride_o_acc,
              ck_tile::index_t batch_stride_lse,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t split_stride_lse_acc,
              ck_tile::index_t split_stride_o_acc)
    {
        Kargs kargs{{lse_acc_ptr,
                     o_acc_ptr,
                     o_ptr,
                     batch,
                     seqlen_q,
                     hdim_v,
                     num_splits,
                     row_stride_o_acc,
                     row_stride_o,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     nhead_stride_o,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {},                   // placeholder for lse
                    {},                   // placeholder for fp8_static_quant args
                    batch_stride_lse_acc,
                    batch_stride_o_acc,
                    batch_stride_o};

        if constexpr(kStoreLSE)
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
            kargs.batch_stride_lse = batch_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_o = scale_o;
        }

        return kargs;
    }

    template <bool Cond = kIsGroupMode>
    __host__ static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* lse_acc_ptr,
              const void* o_acc_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t batch,
              const void* seqstart_q_ptr,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_splits,
              float scale_o,
              ck_tile::index_t row_stride_o_acc,
              ck_tile::index_t row_stride_o,
              ck_tile::index_t nhead_stride_lse_acc,
              ck_tile::index_t nhead_stride_o_acc,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t split_stride_lse_acc,
              ck_tile::index_t split_stride_o_acc)
    {
        Kargs kargs{{lse_acc_ptr,
                     o_acc_ptr,
                     o_ptr,
                     batch,
                     -1, // seqlen_q is determined per batch from seqstart_q_ptr
                     hdim_v,
                     num_splits,
                     row_stride_o_acc,
                     row_stride_o,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     nhead_stride_o,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {},                   // placeholder for lse
                    {},                   // placeholder for fp8_static_quant args
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr)};

        if constexpr(kStoreLSE)
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_o = scale_o;
        }

        return kargs;
    }

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t max_seqlen_q,
                                                ck_tile::index_t hdim_v)
    {
        // TODO: this may need tuning
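        // Illustrative example (hypothetical tile sizes): with max_seqlen_q = 1024, hdim_v = 128,
        // kM0 = 64 and kN1 = 64, gridDim.x = ceil(1024 / 64) * ceil(128 / 64) = 32, i.e. one block
        // per (kM0 x kN1) output tile for every (nhead, batch) pair.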
        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
                        ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
                    nhead,
                    batch_size);
    }

    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
    {
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);

        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck_tile::make_tuple(quotient, modulus);
        };

        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
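        // e.g. (hypothetical) num_tile_n1 = 2 and i_block = 5 give (i_tile_m, i_tile_n) = (2, 1):
        // consecutive blocks first sweep the head-dim (N1) tiles, then advance along seqlen (M0).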

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }

    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        // divide problem
        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
        const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

        long_index_t batch_offset_lse_acc = 0;
        long_index_t batch_offset_o_acc   = 0;
        long_index_t batch_offset_lse     = 0;
        long_index_t batch_offset_o       = 0;

        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];

            batch_offset_lse_acc = query_start;
            batch_offset_o_acc   = query_start * kargs.row_stride_o_acc;

            if constexpr(kStoreLSE)
            {
                batch_offset_lse = query_start;
            }

            batch_offset_o = query_start * kargs.row_stride_o;

            // get the real number of queries under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];

            // the number of required blocks differs across groups; terminate unnecessary
            // blocks early
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }
        }
        else
        {
            batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
            batch_offset_o_acc   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;

            if constexpr(kStoreLSE)
            {
                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
            }

            batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
        }

        // for simplicity, handle the batch stride by simply offsetting the pointers
        const LSEDataType* lse_acc_ptr =
            reinterpret_cast<const LSEDataType*>(kargs.lse_acc_ptr) +
            static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lse_acc + batch_offset_lse_acc;
        const OaccDataType* o_acc_ptr =
            reinterpret_cast<const OaccDataType*>(kargs.o_acc_ptr) +
            static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc;
        ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
                           batch_offset_o;

        // LSEacc/Oacc DRAM and DRAM windows
        const auto lse_acc_dram = [&]() {
            const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                lse_acc_ptr,
                make_tuple(kargs.num_splits, kargs.seqlen_q),
                make_tuple(kargs.split_stride_lse_acc, 1),
                number<1>{},
                number<1>{});

            return pad_tensor_view(
                lse_acc_dram_naive,
                // pad lengths below are assumed (elided in this listing): one (splits, kM0) tile
                make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{}),
                sequence<true, kPadSeqLenQ>{});
        }();

        auto o_acc_dram = [&]() {
            const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_acc_ptr,
                make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
                number<1>{},
                number<1>{});

            // read kNumWarps (kM0, kN1) o_acc tiles simultaneously, one per warp
            // (pad lengths below are assumed; the original arguments are elided in this listing)
            const auto o_acc_dram_view = pad_tensor_view(
                o_acc_dram_naive,
                make_tuple(
                    number<kNumWarps>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                sequence<true, kPadSeqLenQ, kPadHeadDimV>{});

            const index_t padded_num_splits =
                o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<0>{}];
            const index_t padded_seqlen_q =
                o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
            const index_t padded_hdim_v =
                o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];

            const index_t num_m_tiles = integer_divide_floor(padded_seqlen_q, FmhaPipeline::kM0);

            // transform the tensor view by the following steps, given shape:
            // (padded_num_splits, padded_seqlen_q, padded_hdim_v)
            // 1. unmerge to (padded_num_splits, num_m_tiles, kM0, padded_hdim_v)
            // 2. transpose to (num_m_tiles, padded_num_splits, kM0, padded_hdim_v)
            // 3. merge to (num_m_tiles * padded_num_splits * kM0, padded_hdim_v)
            auto transposed = transform_tensor_view(
                o_acc_dram_view,
                make_tuple(make_pass_through_transform(padded_num_splits),
                           make_unmerge_transform(make_tuple(num_m_tiles, FmhaPipeline::kM0)),
                           make_pass_through_transform(padded_hdim_v)),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
                make_tuple(sequence<1>{}, sequence<0, 2>{}, sequence<3>{}));

            return transform_tensor_view(
                transposed,
                make_tuple(make_merge_transform(
                               make_tuple(num_m_tiles, padded_num_splits, FmhaPipeline::kM0)),
                           make_pass_through_transform(padded_hdim_v)),
                make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));
        }();
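
        // Illustrative shapes (hypothetical sizes): num_splits = 6 and kNumWarps = 4 pad to
        // padded_num_splits = 8; with num_m_tiles = 4 and kM0 = 64 the merged o_acc view has
        // 4 * 8 * 64 = 2048 rows, so all splits of one m-tile sit in one contiguous 8 * 64-row slab.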

        auto lse_acc_dram_window = make_tile_window(
            lse_acc_dram,
            // window lengths assumed (elided in this listing): one (splits, kM0) tile
            make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{}),
            {0, i_m0});

        const index_t padded_num_splits =
            integer_divide_ceil(kargs.num_splits, kNumWarps) * kNumWarps;

        auto o_acc_dram_window = make_tile_window(
            o_acc_dram,
            // window lengths assumed (elided in this listing): kNumWarps split-tiles of (kM0, kN1)
            make_tuple(number<kNumWarps * FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
            {i_tile_m * padded_num_splits * FmhaPipeline::kM0, i_n1});

        // LSE DRAM window
        auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
            if constexpr(kStoreLSE)
            {
                LSEDataType* lse_ptr =
                    reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;

                const auto lse_dram = [&]() {
                    const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                        lse_ptr,
                        make_tuple(kargs.seqlen_q),
                        make_tuple(1),
                        number<1>{},
                        number<1>{});

                    return pad_tensor_view(
                        lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
                }();

                return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
            }
            else
            {
                return make_null_tile_window(lse_dram_window_lengths);
            }
        }();

        auto o_acc_tile = [&]() {
            if constexpr(kDoFp8StaticQuant)
            {
                return FmhaPipeline{}(
                    lse_acc_dram_window,
                    o_acc_dram_window,
                    lse_dram_window,
                    identity{},                                          // lse_element_func
                    composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
                    kargs.num_splits,
                    smem_ptr);
            }
            else
            {
                return FmhaPipeline{}(lse_acc_dram_window,
                                      o_acc_dram_window,
                                      lse_dram_window,
                                      kargs.num_splits,
                                      smem_ptr);
            }
        }();
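
        // When kDoFp8StaticQuant is enabled, every o_acc element is first scaled by kargs.scale_o
        // and saturated to the fp8 representable range before the epilogue stores it as ODataType.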

        // O DRAM and DRAM window
        auto o_dram = [&]() {
            const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.row_stride_o, 1),
                number<1>{},
                number<1>{});

            return pad_tensor_view(
                o_dram_naive,
                // pad lengths assumed (elided in this listing): one (kM0, kN1) output tile
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                sequence<kPadSeqLenQ, kPadHeadDimV>{});
        }();

        auto o_dram_window =
            make_tile_window(o_dram,
                             // window lengths assumed: (kM0, kN1)
                             make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                             {i_m0, i_n1});

        EpiloguePipeline{}(o_dram_window, o_acc_tile);
    }
};

} // namespace ck_tile
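
Host-side usage sketch (illustration only): the wrapper below is hypothetical and not part of this
header. `kentry` and `launch_splitkv_combine` are stand-ins for ck_tile's own kernel-launch
helpers, but the MakeKargs/GridSize/BlockSize calls mirror the batch-mode interface defined above.

// Minimal sketch, assuming <hip/hip_runtime.h> and the ck_tile headers are already included,
// and Kernel = ck_tile::FmhaFwdSplitKVCombineKernel<SomePipeline, SomeEpilogue>.
// Hypothetical trampoline: ck_tile kernels are functors, so a __global__ entry point
// constructs the kernel type and forwards the kargs to operator().
template <typename Kernel>
__global__ void kentry(typename Kernel::Kargs kargs)
{
    Kernel{}(kargs);
}

// Batch-mode launch; group mode would pass seqstart_q_ptr in place of seqlen_q
// (see the second MakeKargs overload above).
template <typename Kernel>
void launch_splitkv_combine(const void* lse_acc_ptr, const void* o_acc_ptr,
                            void* lse_ptr, void* o_ptr,
                            ck_tile::index_t batch, ck_tile::index_t nhead,
                            ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v,
                            ck_tile::index_t num_splits, float scale_o,
                            ck_tile::index_t row_stride_o_acc, ck_tile::index_t row_stride_o,
                            ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc,
                            ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o,
                            ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc,
                            ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o,
                            ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc,
                            hipStream_t stream)
{
    // build the kernel arguments exactly as the batch-mode MakeKargs overload expects them
    auto kargs = Kernel::MakeKargs(lse_acc_ptr, o_acc_ptr, lse_ptr, o_ptr, batch, seqlen_q, hdim_v,
                                   num_splits, scale_o, row_stride_o_acc, row_stride_o,
                                   nhead_stride_lse_acc, nhead_stride_o_acc, nhead_stride_lse,
                                   nhead_stride_o, batch_stride_lse_acc, batch_stride_o_acc,
                                   batch_stride_lse, batch_stride_o, split_stride_lse_acc,
                                   split_stride_o_acc);

    // one block per (kM0 x kN1) output tile per (nhead, batch); seqlen_q doubles as max_seqlen_q here
    const dim3 grid  = Kernel::GridSize(batch, nhead, seqlen_q, hdim_v);
    const dim3 block = Kernel::BlockSize();
    kentry<Kernel><<<grid, block, 0, stream>>>(kargs);
}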