/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp Source File
fmha_fwd_appendkv_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
8 #include <string>
9 #include <type_traits>
10 
11 namespace ck_tile {
12 
13 template <typename FmhaPipeline_>
15 {
// Launch-configuration constants forwarded from the pipeline policy.
17  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
18  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
19 
20  static_assert(kBlockPerCu > 0);
// Raw occupancy hint from the problem description; -1 means "unspecified"
// (GetName() omits the "o<N>_" tag in that case).
21  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
26 
// RoPE is applied only when the pipeline selects a rotary-embedding variant.
28  static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != RotaryEmbeddingEnum::NONE;
// Paged KV-cache mode: K/V are addressed through a per-batch block table.
29  static constexpr bool kIsPagedKV = FmhaPipeline::kIsPagedKV;
30 
// Padding flags from the pipeline; they feed both the "p..." name tag in
// GetName() and (presumably) the pad_tensor_view calls in operator() —
// the pad template arguments were lost in extraction, so confirm upstream.
31  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
32  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
33  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
34  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
35 
36  // clang-format off
// t2s: compile-time map from element type to the short tag used in GetName().
37  template <typename T> struct t2s;
38  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
39  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
40  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
41  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
42  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
43  // clang-format on
44 
// Builds the canonical kernel-instance name encoding the tile sizes, data
// type, occupancy override, V layout, padding flags and paged-KV mode.
// The format must stay in sync with the code generator (generate.py).
45  __host__ static std::string GetName()
46  {
47  // sync with generate.py
48  // clang-format off
49 
50  #define _SS_ std::string
51  #define _TS_ std::to_string
// Padding tag: "p" followed by one letter pair per padded dimension,
// or the empty string when nothing is padded.
52  auto pn = [&] () {
53  std::string n;
54  if (kPadSeqLenQ) n += "s";
55  if (kPadSeqLenK) n += "sk";
56  if (kPadHeadDimQ) n += "d";
57  if (kPadHeadDimV) n += "dv";
58  return n.empty() ? n : std::string("p") + n; }();
59  return
60  _SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) + "_"
61  "b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
62  _TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
63  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
// NOTE(review): one continuation line (original line 64) was lost in
// extraction here — presumably the rotary-embedding tag, since kApplyRoPE
// is otherwise not encoded. Recover it from the upstream source.
65  + (kIsPagedKV ? "_pagedkv" : "" );
66  #undef _SS_
67  #undef _TS_
68  // clang-format on
69  }
70 
71  template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a template
72  // arg so each disabled feature gets a distinct (empty) base type in Kargs
73  struct EmptyKargs
74  {
75  };
76 
77  // kargs use aggregate initializer, so no constructor will be provided;
78  // inheritance is used to minimize the karg size.
79  // Users must go through MakeKargs() to create kargs.
// Common arguments shared by every feature configuration. q/k/v are
// mutable because K/V caches are written in place by the append step.
80  struct BasicKargs
81  {
82  void* q_ptr;
83  void* k_ptr;
84  const void* knew_ptr;
85  void* v_ptr;
86  const void* vnew_ptr;
87 
// NOTE(review): the remaining member declarations (original lines 88-117)
// were lost in extraction. Per the aggregate initializer in MakeKargs()
// and the trailing index, the order is: seqlen_k_ptr (const int32_t*),
// seqlen_q, seqlen_k, seqlen_knew, hdim_q, hdim_v, num_head_q,
// nhead_ratio_qk, stride_{q,k,knew,v,vnew}, nhead_stride_{q,k,knew,v,vnew},
// batch_stride_{q,k,knew,v,vnew} — confirm against the upstream header.
89 
95 
97  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k;
98  // if this param is larger than 1, it indicates the MQA/GQA case.
100 
106 
112 
118  };
119 
// Extra arguments present only when kApplyRoPE is true (mixed into Kargs
// via std::conditional_t below).
120  struct RoPEKargs
121  {
122  const void* rotary_cos_ptr;
123  const void* rotary_sin_ptr;
// NOTE(review): the rotary_dim member declaration (original line 124,
// ck_tile::index_t per the trailing index) was lost in extraction.
// has_mask selects the per-row cos/sin indexing in operator() — it is
// multiplied into the row stride of the rotary tables.
125  bool has_mask;
126  };
127 
// NOTE(review): the struct header line was lost in extraction; per the
// trailing index this is PageBlockTableKargs, holding block_table_ptr
// (const int32_t*), batch_stride_block_table and page_block_size — the
// members assigned in MakeKargs() under `if constexpr(kIsPagedKV)`.
129  {
133  };
134 
// NOTE(review): the struct header line was lost in extraction; per the
// trailing index this is CacheBatchIdxKargs, holding cache_batch_idx
// (const int32_t*) — the member assigned in MakeKargs() for the
// non-paged path.
136  {
138  };
139 
// Final kernel-argument type: BasicKargs plus feature mixins selected at
// compile time. EmptyKargs<0> stands in for RoPEKargs when rotary
// embedding is disabled; paged-KV mode selects the block-table fields,
// otherwise the cache_batch_idx field is present.
140  struct Kargs : BasicKargs,
141  std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>,
142  std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
143  {
144  };
145 
// Host-side factory for Kargs. Callers always pass the full argument set;
// the rotary / paged-KV / cache_batch_idx arguments are only copied into
// the kargs when the corresponding compile-time feature is enabled.
// NOTE(review): the brace-init order below must match the BasicKargs
// member layout exactly (aggregate initialization) — do not reorder.
146  __host__ static constexpr Kargs MakeKargs(void* q_ptr,
147  void* k_ptr,
148  const void* knew_ptr,
149  void* v_ptr,
150  const void* vnew_ptr,
151  ck_tile::index_t seqlen_q,
152  const void* seqlen_k_ptr,
153  ck_tile::index_t seqlen_knew,
154  ck_tile::index_t hdim_q,
155  ck_tile::index_t hdim_v,
156  ck_tile::index_t num_head_q,
157  ck_tile::index_t nhead_ratio_qk,
158  const void* rotary_cos_ptr,
159  const void* rotary_sin_ptr,
160  ck_tile::index_t rotary_dim,
161  bool has_mask,
162  const void* block_table_ptr,
163  ck_tile::index_t batch_stride_block_table,
164  ck_tile::index_t page_block_size,
165  const void* cache_batch_idx,
166  ck_tile::index_t stride_q,
167  ck_tile::index_t stride_k,
168  ck_tile::index_t stride_knew,
169  ck_tile::index_t stride_v,
170  ck_tile::index_t stride_vnew,
171  ck_tile::index_t nhead_stride_q,
172  ck_tile::index_t nhead_stride_k,
173  ck_tile::index_t nhead_stride_knew,
174  ck_tile::index_t nhead_stride_v,
175  ck_tile::index_t nhead_stride_vnew,
176  ck_tile::index_t batch_stride_q,
177  ck_tile::index_t batch_stride_k,
178  ck_tile::index_t batch_stride_knew,
179  ck_tile::index_t batch_stride_v,
180  ck_tile::index_t batch_stride_vnew)
181  {
182  Kargs kargs{
183  {q_ptr,
184  k_ptr,
185  knew_ptr,
186  v_ptr,
187  vnew_ptr,
188  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
189  seqlen_q,
// seqlen_k is a device-side value: operator() overwrites this sentinel
// with seqlen_k_ptr[i_batch] per workgroup.
190  -1, // seqlen_k will be updated by content of seqlen_k_ptr
191  seqlen_knew,
192  hdim_q,
193  hdim_v,
194  num_head_q,
195  nhead_ratio_qk,
196  stride_q,
197  stride_k,
198  stride_knew,
199  stride_v,
200  stride_vnew,
201  nhead_stride_q,
202  nhead_stride_k,
203  nhead_stride_knew,
204  nhead_stride_v,
205  nhead_stride_vnew,
206  batch_stride_q,
207  batch_stride_k,
208  batch_stride_knew,
209  batch_stride_v,
210  batch_stride_vnew}, // args for common karg
211  {}, // placeholder for rope
212  {} // placeholder for paged-block table or cache_batch_idx
213  };
214 
// Fill the RoPE mixin only when it exists in Kargs.
215  if constexpr(kApplyRoPE)
216  {
217  kargs.rotary_cos_ptr = rotary_cos_ptr;
218  kargs.rotary_sin_ptr = rotary_sin_ptr;
219  kargs.rotary_dim = rotary_dim;
220  kargs.has_mask = has_mask;
221  }
222 
// Exactly one of the two KV-addressing mixins is present.
223  if constexpr(kIsPagedKV)
224  {
225  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
226  kargs.batch_stride_block_table = batch_stride_block_table;
227  kargs.page_block_size = page_block_size;
228  }
229  else
230  {
231  kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
232  }
233 
234  return kargs;
235  }
236 
237  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
238  ck_tile::index_t nhead,
239  ck_tile::index_t seqlen_q,
240  ck_tile::index_t seqlen_knew)
241  {
242  // TODO: this may need tuning
243  return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, FmhaPipeline::kM0),
244  ck_tile::integer_divide_ceil(seqlen_knew, FmhaPipeline::kN0)),
245  nhead,
246  batch_size);
247  }
248 
249  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& /* kargs */)
250  {
251  const index_t i_tile = blockIdx.x;
252  const index_t i_nhead = blockIdx.y;
253  const index_t i_batch = blockIdx.z;
254 
255  return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
256  }
257 
// Workgroup dimensions: a 1-D block of kBlockSize threads.
258  __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
259 
// Kernel entry point: appends the new K/V rows into the (possibly paged)
// KV cache and, when enabled, applies rotary embedding to Q and Knew,
// delegating the actual tile work to FmhaPipeline.
// NOTE(review): several continuation lines (template arguments of the
// pad_tensor_view / transform_tensor_view / make_tile_window calls) were
// lost in extraction below — the calls are syntactically incomplete here;
// recover them from the upstream header before editing.
260  CK_TILE_DEVICE void operator()(Kargs kargs) const
261  {
262  // divide problem
263  const auto [i_tile, i_nhead, i_batch] = GetTileIndex(kargs);
264 
// Tile origins along seqlen_q (for Q/RoPE) and seqlen_knew (for append);
// readfirstlane broadcasts the scalar to keep it in an SGPR.
265  const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
266  const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
267 
// Non-paged caches may be indirected through cache_batch_idx; paged
// caches address blocks via the block table instead, so use i_batch as-is.
268  const index_t i_cache_batch = [&, i_batch_ = i_batch] {
269  if constexpr(kIsPagedKV)
270  {
271  return i_batch_;
272  }
273  else
274  {
275  return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
276  : i_batch_);
277  }
278  }();
279 
// 64-bit batch offsets; K/V use the (possibly remapped) cache batch index.
280  const long_index_t batch_offset_q =
281  static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
282  const long_index_t batch_offset_k =
283  static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
284  const long_index_t batch_offset_knew =
285  static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
286  const long_index_t batch_offset_v =
287  static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
288  const long_index_t batch_offset_vnew =
289  static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;
290 
// Resolve the per-batch current cache length (replaces the -1 sentinel
// written by MakeKargs).
291  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
292 
293  // for simplicity of batch stride we just offset the pointers up front
// K/V(new) share heads across the MQA/GQA group: i_nhead / nhead_ratio_qk.
294  QDataType* q_ptr = reinterpret_cast<QDataType*>(kargs.q_ptr) +
295  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
296  batch_offset_q;
297  KDataType* k_ptr =
298  reinterpret_cast<KDataType*>(kargs.k_ptr) +
299  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
300  batch_offset_k;
301  const KDataType* knew_ptr =
302  reinterpret_cast<const KDataType*>(kargs.knew_ptr) +
303  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_knew +
304  batch_offset_knew;
305  VDataType* v_ptr =
306  reinterpret_cast<VDataType*>(kargs.v_ptr) +
307  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
308  batch_offset_v;
309  const VDataType* vnew_ptr =
310  reinterpret_cast<const VDataType*>(kargs.vnew_ptr) +
311  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_vnew +
312  batch_offset_vnew;
313 
314  // Q/K/V DRAM and DRAM window
// Q view: [seqlen_q, hdim_q], row stride stride_q, padded to tile bounds.
315  const auto q_dram = [&]() {
316  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
317  q_ptr,
318  make_tuple(kargs.seqlen_q, kargs.hdim_q),
319  make_tuple(kargs.stride_q, 1),
321  number<1>{});
322 
323  return pad_tensor_view(
324  q_dram_naive,
327  }();
328 
// K view factory, parameterized on base pointer and row count so the
// paged path can build both full-block and tail-block views.
329  const auto make_k_dram = [&](KDataType* data, index_t height) {
330  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
331  data, // will update this pointer if using paged-kvcache
332  make_tuple(height, kargs.hdim_q),
333  make_tuple(kargs.stride_k, 1),
335  number<1>{});
336 
337  return pad_tensor_view(
338  k_dram_naive,
341  };
342  const auto k_dram = [&]() {
343  if constexpr(kIsPagedKV)
344  {
// Paged: the navigator supplies real pointers block by block.
345  return make_k_dram(nullptr, kargs.page_block_size);
346  }
347  else
348  {
// Contiguous: cache must hold old rows plus the appended rows.
349  return make_k_dram(k_ptr, kargs.seqlen_k + kargs.seqlen_knew);
350  }
351  }();
352 
// Knew view: [seqlen_knew, hdim_q], the rows to be appended.
353  const auto knew_dram = [&]() {
354  const auto knew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
355  knew_ptr,
356  make_tuple(kargs.seqlen_knew, kargs.hdim_q),
357  make_tuple(kargs.stride_knew, 1),
359  number<1>{});
360 
361  return pad_tensor_view(
362  knew_dram_naive,
365  }();
366 
// V view factory: a row-major V is viewed transposed so both layouts
// present the same [hdim_v, length]-style window to the pipeline.
367  const auto make_v_dram = [&](VDataType* data, index_t length) {
368  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
369  {
370  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
371  data, // will update this pointer if using paged-kvcache
372  make_tuple(length, kargs.hdim_v),
373  make_tuple(kargs.stride_v, 1),
375  number<1>{});
376 
377  const auto v_dram_transposed =
378  transform_tensor_view(v_dram_naive,
383 
384  return pad_tensor_view(
385  v_dram_transposed,
388  }
389  else
390  {
391  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
392  data, // will update this pointer if using paged-kvcache
393  make_tuple(kargs.hdim_v, length),
394  make_tuple(kargs.stride_v, 1),
396  number<1>{});
397 
398  return pad_tensor_view(
399  v_dram_naive,
402  }
403  };
404  const auto v_dram = [&]() {
405  if constexpr(kIsPagedKV)
406  {
407  return make_v_dram(nullptr, kargs.page_block_size);
408  }
409  else
410  {
411  return make_v_dram(v_ptr, kargs.seqlen_k + kargs.seqlen_knew);
412  }
413  }();
414 
// Vnew view, transposed the same way as V when row-major.
415  const auto vnew_dram = [&]() {
416  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
417  {
418  const auto vnew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
419  vnew_ptr,
420  make_tuple(kargs.seqlen_knew, kargs.hdim_v),
421  make_tuple(kargs.stride_vnew, 1),
423  number<1>{});
424 
425  const auto vnew_dram_transposed = transform_tensor_view(
426  vnew_dram_naive,
431 
432  return pad_tensor_view(
433  vnew_dram_transposed,
436  }
437  else
438  {
439  const auto vnew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
440  vnew_ptr,
441  make_tuple(kargs.hdim_v, kargs.seqlen_knew),
442  make_tuple(kargs.stride_vnew, 1),
444  number<1>{});
445 
446  return pad_tensor_view(
447  vnew_dram_naive,
450  }
451  }();
452 
// Rotary cos/sin windows for Q: [kM0, kK0/2] tiles. The tables are read
// starting at row seqlen_k; with has_mask the row stride is rotary_dim/2
// (per-row angles), otherwise 0 (all rows share one angle row).
453  constexpr auto q_rotary_cos_sin_dram_window_lengths =
454  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0 / 2>{});
455  const auto q_rotary_cos_dram_window = [&]() {
456  if constexpr(kApplyRoPE)
457  {
458  const auto rotary_cos_dram_native =
459  make_naive_tensor_view<address_space_enum::global>(
460  reinterpret_cast<const QDataType*>(kargs.rotary_cos_ptr) +
461  kargs.seqlen_k * (kargs.rotary_dim / 2),
462  make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
463  make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1),
464  number<8>{},
465  number<1>{});
466 
467  const auto rotary_cos_dram = [&]() {
468  return pad_tensor_view(rotary_cos_dram_native,
469  q_rotary_cos_sin_dram_window_lengths,
471  }();
472 
473  return make_tile_window(
474  rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
475  }
476  else
477  {
478  return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
479  }
480  }();
481  const auto q_rotary_sin_dram_window = [&]() {
482  if constexpr(kApplyRoPE)
483  {
484  const auto rotary_sin_dram_native =
485  make_naive_tensor_view<address_space_enum::global>(
486  reinterpret_cast<const QDataType*>(kargs.rotary_sin_ptr) +
487  kargs.seqlen_k * (kargs.rotary_dim / 2),
488  make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
489  make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1),
490  number<8>{},
491  number<1>{});
492 
493  const auto rotary_sin_dram = [&]() {
494  return pad_tensor_view(rotary_sin_dram_native,
495  q_rotary_cos_sin_dram_window_lengths,
497  }();
498 
499  return make_tile_window(
500  rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
501  }
502  else
503  {
504  return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
505  }
506  }();
507 
// Rotary cos/sin windows for Knew: [kN0, kK0/2] tiles; Knew rows always
// get per-row angles (row stride rotary_dim/2, no has_mask factor).
508  constexpr auto knew_rotary_cos_sin_dram_window_lengths =
509  make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0 / 2>{});
510  const auto knew_rotary_cos_dram_window = [&]() {
511  if constexpr(kApplyRoPE)
512  {
513  const auto rotary_cos_dram_native =
514  make_naive_tensor_view<address_space_enum::global>(
515  reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr) +
516  kargs.seqlen_k * (kargs.rotary_dim / 2),
517  make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2),
518  make_tuple(kargs.rotary_dim / 2, 1),
519  number<8>{},
520  number<1>{});
521 
522  const auto rotary_cos_dram = [&]() {
523  return pad_tensor_view(rotary_cos_dram_native,
524  knew_rotary_cos_sin_dram_window_lengths,
526  }();
527 
528  return make_tile_window(
529  rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
530  }
531  else
532  {
533  return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
534  }
535  }();
536  const auto knew_rotary_sin_dram_window = [&]() {
537  if constexpr(kApplyRoPE)
538  {
539  const auto rotary_sin_dram_native =
540  make_naive_tensor_view<address_space_enum::global>(
541  reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr) +
542  kargs.seqlen_k * (kargs.rotary_dim / 2),
543  make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2),
544  make_tuple(kargs.rotary_dim / 2, 1),
545  number<8>{},
546  number<1>{});
547 
548  const auto rotary_sin_dram = [&]() {
549  return pad_tensor_view(rotary_sin_dram_native,
550  knew_rotary_cos_sin_dram_window_lengths,
552  }();
553 
554  return make_tile_window(
555  rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
556  }
557  else
558  {
559  return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
560  }
561  }();
562 
// Navigators translate logical K/V row coordinates into physical pages.
// The last page may be partially filled, hence the separate tail view.
563  auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
564  if constexpr(kIsPagedKV)
565  {
566  const auto* block_indices =
567  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
568  i_batch_ * kargs.batch_stride_block_table;
569  const index_t num_blocks =
570  integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);
571 
572  const long_index_t fixed_offset =
573  static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
574  kargs.nhead_stride_k;
575 
576  return make_page_block_navigator<KDataType, 0>(
577  kargs.k_ptr,
578  kargs.batch_stride_k,
579  fixed_offset,
580  block_indices,
581  num_blocks,
582  kargs.page_block_size,
583  k_dram,
584  make_k_dram(nullptr,
585  (kargs.seqlen_k + kargs.seqlen_knew) -
586  (num_blocks - 1) * kargs.page_block_size));
587  }
588  else
589  {
590  return make_page_block_navigator(k_dram);
591  }
592  }();
593 
594  auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
595  if constexpr(kIsPagedKV)
596  {
597  const auto* block_indices =
598  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
599  i_batch_ * kargs.batch_stride_block_table;
600  const index_t num_blocks =
601  integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);
602 
603  const long_index_t fixed_offset =
604  static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
605  kargs.nhead_stride_v;
606 
607  return make_page_block_navigator<VDataType, 1>(
608  kargs.v_ptr,
609  kargs.batch_stride_v,
610  fixed_offset,
611  block_indices,
612  num_blocks,
613  kargs.page_block_size,
614  v_dram,
615  make_v_dram(nullptr,
616  (kargs.seqlen_k + kargs.seqlen_knew) -
617  (num_blocks - 1) * kargs.page_block_size));
618  }
619  else
620  {
621  return make_page_block_navigator(v_dram);
622  }
623  }();
624 
625  auto q_dram_window =
626  make_tile_window(q_dram,
628  {i_m0, 0});
629 
// This tile has K/V rows to append only while i_n0 is inside seqlen_knew;
// otherwise it still runs (for RoPE on Q) but skips the append.
630  const bool skip_append_kv = kargs.seqlen_knew <= i_n0;
631  // window origin = (0, 0) if no work to do for current block
632  auto [i_page_block_k, k_dram_window] = k_page_block_navigator.make_tile_window(
634  {!skip_append_kv * (kargs.seqlen_k + i_n0), 0});
635 
636  auto knew_dram_window =
637  make_tile_window(knew_dram,
639  {i_n0, 0});
640 
641  // window origin = (0, 0) if no work to do for current block
642  auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
644  {0, !skip_append_kv * (kargs.seqlen_k + i_n0)});
645 
646  auto vnew_dram_window =
647  make_tile_window(vnew_dram,
649  {0, i_n0});
650 
651  // If kApplyRoPE is false, we set rotary_dim to 0
652  auto rotary_dim = [&]() {
653  if constexpr(kApplyRoPE)
654  return kargs.rotary_dim;
655  else
656  return 0;
657  }();
// Hand everything to the pipeline; the last two flags tell it whether the
// Q-RoPE part and the append part respectively are out of range for this
// tile.
658  FmhaPipeline{}(q_dram_window,
659  k_dram_window,
660  i_page_block_k,
661  k_page_block_navigator,
662  knew_dram_window,
663  v_dram_window,
664  i_page_block_v,
665  v_page_block_navigator,
666  vnew_dram_window,
667  q_rotary_cos_dram_window,
668  q_rotary_sin_dram_window,
669  knew_rotary_cos_dram_window,
670  knew_rotary_sin_dram_window,
671  rotary_dim,
672  kargs.seqlen_q <= i_m0,
673  skip_append_kv);
674  }
675 };
676 
677 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define _TS_
#define _SS_
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition: page_block_navigator.hpp:333
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:511
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: fmha_fwd_appendkv_kernel.hpp:81
ck_tile::index_t stride_q
Definition: fmha_fwd_appendkv_kernel.hpp:101
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:88
const void * knew_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:84
ck_tile::index_t stride_k
Definition: fmha_fwd_appendkv_kernel.hpp:102
ck_tile::index_t batch_stride_knew
Definition: fmha_fwd_appendkv_kernel.hpp:115
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_appendkv_kernel.hpp:116
ck_tile::index_t nhead_stride_knew
Definition: fmha_fwd_appendkv_kernel.hpp:109
ck_tile::index_t nhead_stride_vnew
Definition: fmha_fwd_appendkv_kernel.hpp:111
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_appendkv_kernel.hpp:114
ck_tile::index_t stride_knew
Definition: fmha_fwd_appendkv_kernel.hpp:103
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_appendkv_kernel.hpp:107
ck_tile::index_t hdim_q
Definition: fmha_fwd_appendkv_kernel.hpp:93
ck_tile::index_t stride_v
Definition: fmha_fwd_appendkv_kernel.hpp:104
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_appendkv_kernel.hpp:113
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_appendkv_kernel.hpp:108
ck_tile::index_t hdim_v
Definition: fmha_fwd_appendkv_kernel.hpp:94
ck_tile::index_t batch_stride_vnew
Definition: fmha_fwd_appendkv_kernel.hpp:117
void * q_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:82
ck_tile::index_t stride_vnew
Definition: fmha_fwd_appendkv_kernel.hpp:105
void * v_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:85
const void * vnew_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:86
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_appendkv_kernel.hpp:110
void * k_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:83
ck_tile::index_t seqlen_k
Definition: fmha_fwd_appendkv_kernel.hpp:91
ck_tile::index_t seqlen_q
Definition: fmha_fwd_appendkv_kernel.hpp:90
ck_tile::index_t seqlen_knew
Definition: fmha_fwd_appendkv_kernel.hpp:92
ck_tile::index_t num_head_q
Definition: fmha_fwd_appendkv_kernel.hpp:96
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_appendkv_kernel.hpp:99
Definition: fmha_fwd_appendkv_kernel.hpp:136
const int32_t * cache_batch_idx
Definition: fmha_fwd_appendkv_kernel.hpp:137
Definition: fmha_fwd_appendkv_kernel.hpp:74
Definition: fmha_fwd_appendkv_kernel.hpp:143
Definition: fmha_fwd_appendkv_kernel.hpp:129
ck_tile::index_t batch_stride_block_table
Definition: fmha_fwd_appendkv_kernel.hpp:131
const int32_t * block_table_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:130
ck_tile::index_t page_block_size
Definition: fmha_fwd_appendkv_kernel.hpp:132
Definition: fmha_fwd_appendkv_kernel.hpp:121
ck_tile::index_t rotary_dim
Definition: fmha_fwd_appendkv_kernel.hpp:124
const void * rotary_sin_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:123
bool has_mask
Definition: fmha_fwd_appendkv_kernel.hpp:125
const void * rotary_cos_ptr
Definition: fmha_fwd_appendkv_kernel.hpp:122
Definition: fmha_fwd_appendkv_kernel.hpp:37
Definition: fmha_fwd_appendkv_kernel.hpp:15
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_appendkv_kernel.hpp:34
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_appendkv_kernel.hpp:21
static constexpr __host__ auto BlockSize()
Definition: fmha_fwd_appendkv_kernel.hpp:258
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_appendkv_kernel.hpp:16
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_appendkv_kernel.hpp:23
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_appendkv_kernel.hpp:24
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_knew)
Definition: fmha_fwd_appendkv_kernel.hpp:237
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_appendkv_kernel.hpp:31
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_appendkv_kernel.hpp:260
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_appendkv_kernel.hpp:27
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_appendkv_kernel.hpp:18
static constexpr __host__ Kargs MakeKargs(void *q_ptr, void *k_ptr, const void *knew_ptr, void *v_ptr, const void *vnew_ptr, ck_tile::index_t seqlen_q, const void *seqlen_k_ptr, ck_tile::index_t seqlen_knew, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *rotary_cos_ptr, const void *rotary_sin_ptr, ck_tile::index_t rotary_dim, bool has_mask, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_knew, ck_tile::index_t stride_v, ck_tile::index_t stride_vnew, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_knew, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_vnew, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_knew, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_vnew)
Definition: fmha_fwd_appendkv_kernel.hpp:146
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_appendkv_kernel.hpp:25
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_appendkv_kernel.hpp:33
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_appendkv_kernel.hpp:32
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &)
Definition: fmha_fwd_appendkv_kernel.hpp:249
static constexpr bool kIsPagedKV
Definition: fmha_fwd_appendkv_kernel.hpp:29
static constexpr bool kApplyRoPE
Definition: fmha_fwd_appendkv_kernel.hpp:28
static __host__ std::string GetName()
Definition: fmha_fwd_appendkv_kernel.hpp:45
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_appendkv_kernel.hpp:17
Definition: block_rotary_embedding.hpp:19
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49