Composable Kernel: include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

template <typename FmhaPipeline_>
struct FmhaFwdAppendKVKernel
{
    using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
    static constexpr ck_tile::index_t kBlockSize  = FmhaPipeline::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
    static_assert(kBlockPerCu > 0);
    static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

    using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
    using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
    using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;

    using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
    static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != RotaryEmbeddingEnum::NONE;
    static constexpr bool kIsPagedKV = FmhaPipeline::kIsPagedKV;

    static constexpr bool kPadSeqLenQ  = FmhaPipeline::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK  = FmhaPipeline::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on

43 
44  __host__ static std::string GetName()
45  {
46  // sync with generate.py
47  // clang-format off
48 
49  #define _SS_ std::string
50  #define _TS_ std::to_string
51  auto pn = [&] () {
52  std::string n;
53  if (kPadSeqLenQ) n += "s";
54  if (kPadSeqLenK) n += "sk";
55  if (kPadHeadDimQ) n += "d";
56  if (kPadHeadDimV) n += "dv";
57  return n.empty() ? n : std::string("p") + n; }();
58  return
59  _SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) + "_"
60  "b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
61  _TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
62  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
64  + (kIsPagedKV ? "_pagedkv" : "" );
65  #undef _SS_
66  #undef _TS_
67  // clang-format on
68  }
69 
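    // Example of a generated name (hypothetical tile configuration: kK0 = 128, fp16,
    // kM0 = 64, kN0 = 64, kN1 = 128, row-major V, default occupancy, no padding,
    // no RoPE, paged KV): "fmha_fwd_appendkv_d128_fp16_b64x64x128x128_vr_pagedkv"
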
    template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a
                                  // template arg
    struct EmptyKargs
    {
    };

    // kargs use aggregate initialization, so no constructor is provided.
    // Inheritance is used to minimize the karg size.
    // Use the MakeKargs() function to create kargs.
    struct BasicKargs
    {
        void* q_ptr;
        void* k_ptr;
        const void* knew_ptr;
        void* v_ptr;
        const void* vnew_ptr;

        const int32_t* seqlen_k_ptr;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t seqlen_k;
        ck_tile::index_t seqlen_knew;
        ck_tile::index_t hdim_q;
        ck_tile::index_t hdim_v;

        ck_tile::index_t num_head_q;
        // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k;
        // if it is larger than 1, it indicates an MQA/GQA case
        ck_tile::index_t nhead_ratio_qk;

        ck_tile::index_t stride_q;
        ck_tile::index_t stride_k;
        ck_tile::index_t stride_knew;
        ck_tile::index_t stride_v;
        ck_tile::index_t stride_vnew;

        ck_tile::index_t nhead_stride_q;
        ck_tile::index_t nhead_stride_k;
        ck_tile::index_t nhead_stride_knew;
        ck_tile::index_t nhead_stride_v;
        ck_tile::index_t nhead_stride_vnew;

        ck_tile::index_t batch_stride_q;
        ck_tile::index_t batch_stride_k;
        ck_tile::index_t batch_stride_knew;
        ck_tile::index_t batch_stride_v;
        ck_tile::index_t batch_stride_vnew;
    };

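    // Worked example (illustrative values): with num_head_q = 32 query heads over
    // 8 KV heads, nhead_ratio_qk = 32 / 8 = 4, so query heads 0..3 all read KV
    // head 0 via the i_nhead / nhead_ratio_qk mapping in operator() below.
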
    struct RoPEKargs
    {
        const void* rotary_cos_ptr;
        const void* rotary_sin_ptr;
        ck_tile::index_t rotary_dim;
        bool has_mask;
    };

    struct PageBlockTableKargs
    {
        const int32_t* block_table_ptr;
        ck_tile::index_t batch_stride_block_table;
        ck_tile::index_t page_block_size;
    };

    struct CacheBatchIdxKargs
    {
        const int32_t* cache_batch_idx;
    };

    struct Kargs : BasicKargs,
                   std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>,
                   std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
    {
    };

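    // Note: only the karg fields a given instantiation needs are carried. E.g. with
    // kApplyRoPE == false and kIsPagedKV == false, Kargs inherits EmptyKargs<0> and
    // CacheBatchIdxKargs, so the RoPE pointers and block-table fields do not exist
    // at all; the EmptyKargs<I> template arg keeps multiple empty bases distinct.
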
    __host__ static constexpr Kargs MakeKargs(void* q_ptr,
                                              void* k_ptr,
                                              const void* knew_ptr,
                                              void* v_ptr,
                                              const void* vnew_ptr,
                                              ck_tile::index_t seqlen_q,
                                              const void* seqlen_k_ptr,
                                              ck_tile::index_t seqlen_knew,
                                              ck_tile::index_t hdim_q,
                                              ck_tile::index_t hdim_v,
                                              ck_tile::index_t num_head_q,
                                              ck_tile::index_t nhead_ratio_qk,
                                              const void* rotary_cos_ptr,
                                              const void* rotary_sin_ptr,
                                              ck_tile::index_t rotary_dim,
                                              bool has_mask,
                                              const void* block_table_ptr,
                                              ck_tile::index_t batch_stride_block_table,
                                              ck_tile::index_t page_block_size,
                                              const void* cache_batch_idx,
                                              ck_tile::index_t stride_q,
                                              ck_tile::index_t stride_k,
                                              ck_tile::index_t stride_knew,
                                              ck_tile::index_t stride_v,
                                              ck_tile::index_t stride_vnew,
                                              ck_tile::index_t nhead_stride_q,
                                              ck_tile::index_t nhead_stride_k,
                                              ck_tile::index_t nhead_stride_knew,
                                              ck_tile::index_t nhead_stride_v,
                                              ck_tile::index_t nhead_stride_vnew,
                                              ck_tile::index_t batch_stride_q,
                                              ck_tile::index_t batch_stride_k,
                                              ck_tile::index_t batch_stride_knew,
                                              ck_tile::index_t batch_stride_v,
                                              ck_tile::index_t batch_stride_vnew)
    {
        Kargs kargs{
            {q_ptr,
             k_ptr,
             knew_ptr,
             v_ptr,
             vnew_ptr,
             reinterpret_cast<const int32_t*>(seqlen_k_ptr),
             seqlen_q,
             -1, // seqlen_k will be updated by content of seqlen_k_ptr
             seqlen_knew,
             hdim_q,
             hdim_v,
             num_head_q,
             nhead_ratio_qk,
             stride_q,
             stride_k,
             stride_knew,
             stride_v,
             stride_vnew,
             nhead_stride_q,
             nhead_stride_k,
             nhead_stride_knew,
             nhead_stride_v,
             nhead_stride_vnew,
             batch_stride_q,
             batch_stride_k,
             batch_stride_knew,
             batch_stride_v,
             batch_stride_vnew}, // args for common karg
            {},                  // placeholder for rope
            {}                   // placeholder for paged-block table or cache_batch_idx
        };

        if constexpr(kApplyRoPE)
        {
            kargs.rotary_cos_ptr = rotary_cos_ptr;
            kargs.rotary_sin_ptr = rotary_sin_ptr;
            kargs.rotary_dim     = rotary_dim;
            kargs.has_mask       = has_mask;
        }

        if constexpr(kIsPagedKV)
        {
            kargs.block_table_ptr          = reinterpret_cast<const int32_t*>(block_table_ptr);
            kargs.batch_stride_block_table = batch_stride_block_table;
            kargs.page_block_size          = page_block_size;
        }
        else
        {
            kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
        }

        return kargs;
    }

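    // Host-side stride sketch (illustrative, not part of this file): for contiguous
    // [batch, nhead, seqlen, hdim] tensors one would typically pass
    //   stride_q       = hdim_q
    //   nhead_stride_q = seqlen_q * hdim_q
    //   batch_stride_q = num_head_q * seqlen_q * hdim_q
    // and the analogous strides for k/knew/v/vnew.
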
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t seqlen_q,
                                                ck_tile::index_t seqlen_knew)
    {
        // TODO: this may need tuning
        return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, FmhaPipeline::kM0),
                             ck_tile::integer_divide_ceil(seqlen_knew, FmhaPipeline::kN0)),
                    nhead,
                    batch_size);
    }

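    // Example (hypothetical tile sizes kM0 = 64, kN0 = 64): GridSize(2, 16, 1, 130)
    // = dim3(max(ceil(1/64), ceil(130/64)), 16, 2) = dim3(3, 16, 2); the x dimension
    // covers whichever of the Q tiles or the appended-KV tiles needs more blocks.
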
    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& /* kargs */)
    {
        const index_t i_tile  = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
    }

    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

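    // Host-side launch sketch (illustrative; the pipeline type and launch wrapper
    // are hypothetical, not defined in this file):
    //   using Kernel = FmhaFwdAppendKVKernel<MyPipeline>;
    //   auto kargs   = Kernel::MakeKargs(/* pointers, lengths, strides... */);
    //   dim3 grids   = Kernel::GridSize(batch, nhead, seqlen_q, seqlen_knew);
    //   dim3 blocks  = Kernel::BlockSize();
    //   // then launch Kernel{}(kargs) over (grids, blocks) with the runtime's wrapper
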
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // divide problem
        const auto [i_tile, i_nhead, i_batch] = GetTileIndex(kargs);

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
        const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);

        const index_t i_cache_batch = [&, i_batch_ = i_batch] {
            if constexpr(kIsPagedKV)
            {
                return i_batch_;
            }
            else
            {
                return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
                                                         : i_batch_);
            }
        }();

        const long_index_t batch_offset_q =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
        const long_index_t batch_offset_k =
            static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
        const long_index_t batch_offset_knew =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
        const long_index_t batch_offset_v =
            static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
        const long_index_t batch_offset_vnew =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;

        kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];

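        // seqlen_k_ptr holds the per-batch cache length *before* the append, so the
        // new keys/values land in rows [seqlen_k, seqlen_k + seqlen_knew) of the cache.
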
        // for simplicity, we fold the batch and nhead strides into the pointers directly
        QDataType* q_ptr = reinterpret_cast<QDataType*>(kargs.q_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
                           batch_offset_q;
        KDataType* k_ptr =
            reinterpret_cast<KDataType*>(kargs.k_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
            batch_offset_k;
        const KDataType* knew_ptr =
            reinterpret_cast<const KDataType*>(kargs.knew_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_knew +
            batch_offset_knew;
        VDataType* v_ptr =
            reinterpret_cast<VDataType*>(kargs.v_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
            batch_offset_v;
        const VDataType* vnew_ptr =
            reinterpret_cast<const VDataType*>(kargs.vnew_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_vnew +
            batch_offset_vnew;

        // Q/K/V DRAM and DRAM window
        const auto q_dram = [&]() {
            const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                q_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_q, 1),
                number<FmhaPipeline::kAlignmentQ>{},
                number<1>{});

            return pad_tensor_view(
                q_dram_naive,
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }();

        const auto make_k_dram = [&](KDataType* data, index_t height) {
            const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                data, // will update this pointer if using paged-kvcache
                make_tuple(height, kargs.hdim_q),
                make_tuple(kargs.stride_k, 1),
                number<FmhaPipeline::kAlignmentK>{},
                number<1>{});

            return pad_tensor_view(
                k_dram_naive,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
                sequence<kPadSeqLenK, kPadHeadDimQ>{});
        };
        const auto k_dram = [&]() {
            if constexpr(kIsPagedKV)
            {
                return make_k_dram(nullptr, kargs.page_block_size);
            }
            else
            {
                return make_k_dram(k_ptr, kargs.seqlen_k + kargs.seqlen_knew);
            }
        }();

        const auto knew_dram = [&]() {
            const auto knew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                knew_ptr,
                make_tuple(kargs.seqlen_knew, kargs.hdim_q),
                make_tuple(kargs.stride_knew, 1),
                number<FmhaPipeline::kAlignmentK>{},
                number<1>{});

            return pad_tensor_view(
                knew_dram_naive,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
                sequence<kPadSeqLenK, kPadHeadDimQ>{});
        }();

        const auto make_v_dram = [&](VDataType* data, index_t length) {
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    data, // will update this pointer if using paged-kvcache
                    make_tuple(length, kargs.hdim_v),
                    make_tuple(kargs.stride_v, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});

                const auto v_dram_transposed =
                    transform_tensor_view(v_dram_naive,
                                          make_tuple(make_pass_through_transform(kargs.hdim_v),
                                                     make_pass_through_transform(length)),
                                          make_tuple(sequence<1>{}, sequence<0>{}),
                                          make_tuple(sequence<0>{}, sequence<1>{}));

                return pad_tensor_view(
                    v_dram_transposed,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK>{});
            }
            else
            {
                const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    data, // will update this pointer if using paged-kvcache
                    make_tuple(kargs.hdim_v, length),
                    make_tuple(kargs.stride_v, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});

                return pad_tensor_view(
                    v_dram_naive,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK>{});
            }
        };
        const auto v_dram = [&]() {
            if constexpr(kIsPagedKV)
            {
                return make_v_dram(nullptr, kargs.page_block_size);
            }
            else
            {
                return make_v_dram(v_ptr, kargs.seqlen_k + kargs.seqlen_knew);
            }
        }();

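        // Note: for a row-major V the naive (seqlen, hdim_v) view is re-indexed to
        // (hdim_v, seqlen) with pass-through transforms, so both VLayout branches
        // hand the pipeline the same (hdim_v, seqlen) shape.
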
        const auto vnew_dram = [&]() {
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                const auto vnew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    vnew_ptr,
                    make_tuple(kargs.seqlen_knew, kargs.hdim_v),
                    make_tuple(kargs.stride_vnew, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});

                const auto vnew_dram_transposed = transform_tensor_view(
                    vnew_dram_naive,
                    make_tuple(make_pass_through_transform(kargs.hdim_v),
                               make_pass_through_transform(kargs.seqlen_knew)),
                    make_tuple(sequence<1>{}, sequence<0>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}));

                return pad_tensor_view(
                    vnew_dram_transposed,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK>{});
            }
            else
            {
                const auto vnew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    vnew_ptr,
                    make_tuple(kargs.hdim_v, kargs.seqlen_knew),
                    make_tuple(kargs.stride_vnew, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});

                return pad_tensor_view(
                    vnew_dram_naive,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK>{});
            }
        }();

        constexpr auto q_rotary_cos_sin_dram_window_lengths =
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0 / 2>{});
        const auto q_rotary_cos_dram_window = [&]() {
            if constexpr(kApplyRoPE)
            {
                const auto rotary_cos_dram_native =
                    make_naive_tensor_view<address_space_enum::global>(
                        reinterpret_cast<const QDataType*>(kargs.rotary_cos_ptr) +
                            kargs.seqlen_k * (kargs.rotary_dim / 2),
                        make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
                        make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1),
                        number<8>{},
                        number<1>{});

                const auto rotary_cos_dram = [&]() {
                    return pad_tensor_view(rotary_cos_dram_native,
                                           q_rotary_cos_sin_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadHeadDimQ>{});
                }();

                return make_tile_window(
                    rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
            }
            else
            {
                return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
            }
        }();
        const auto q_rotary_sin_dram_window = [&]() {
            if constexpr(kApplyRoPE)
            {
                const auto rotary_sin_dram_native =
                    make_naive_tensor_view<address_space_enum::global>(
                        reinterpret_cast<const QDataType*>(kargs.rotary_sin_ptr) +
                            kargs.seqlen_k * (kargs.rotary_dim / 2),
                        make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
                        make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1),
                        number<8>{},
                        number<1>{});

                const auto rotary_sin_dram = [&]() {
                    return pad_tensor_view(rotary_sin_dram_native,
                                           q_rotary_cos_sin_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadHeadDimQ>{});
                }();

                return make_tile_window(
                    rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
            }
            else
            {
                return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
            }
        }();

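        // The row stride has_mask * (rotary_dim / 2) is a broadcast trick: without a
        // mask every query row rotates with the same position (seqlen_k), so stride 0
        // re-reads one cos/sin row; with a mask, row m uses position seqlen_k + m.
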
        constexpr auto knew_rotary_cos_sin_dram_window_lengths =
            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0 / 2>{});
        const auto knew_rotary_cos_dram_window = [&]() {
            if constexpr(kApplyRoPE)
            {
                const auto rotary_cos_dram_native =
                    make_naive_tensor_view<address_space_enum::global>(
                        reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr) +
                            kargs.seqlen_k * (kargs.rotary_dim / 2),
                        make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2),
                        make_tuple(kargs.rotary_dim / 2, 1),
                        number<8>{},
                        number<1>{});

                const auto rotary_cos_dram = [&]() {
                    return pad_tensor_view(rotary_cos_dram_native,
                                           knew_rotary_cos_sin_dram_window_lengths,
                                           sequence<kPadSeqLenK, kPadHeadDimQ>{});
                }();

                return make_tile_window(
                    rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
            }
            else
            {
                return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
            }
        }();
        const auto knew_rotary_sin_dram_window = [&]() {
            if constexpr(kApplyRoPE)
            {
                const auto rotary_sin_dram_native =
                    make_naive_tensor_view<address_space_enum::global>(
                        reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr) +
                            kargs.seqlen_k * (kargs.rotary_dim / 2),
                        make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2),
                        make_tuple(kargs.rotary_dim / 2, 1),
                        number<8>{},
                        number<1>{});

                const auto rotary_sin_dram = [&]() {
                    return pad_tensor_view(rotary_sin_dram_native,
                                           knew_rotary_cos_sin_dram_window_lengths,
                                           sequence<kPadSeqLenK, kPadHeadDimQ>{});
                }();

                return make_tile_window(
                    rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
            }
            else
            {
                return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
            }
        }();

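        // Unlike the Q windows above, the new-key rows always advance by one position
        // each (row stride rotary_dim / 2), since appended keys occupy consecutive
        // cache positions seqlen_k, seqlen_k + 1, ... regardless of has_mask.
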
        auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
            if constexpr(kIsPagedKV)
            {
                const auto* block_indices =
                    reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
                    i_batch_ * kargs.batch_stride_block_table;
                const index_t num_blocks =
                    integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);

                const long_index_t fixed_offset =
                    static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
                    kargs.nhead_stride_k;

                return make_page_block_navigator<KDataType, 0>(
                    kargs.k_ptr,
                    kargs.batch_stride_k,
                    fixed_offset,
                    block_indices,
                    num_blocks,
                    kargs.page_block_size,
                    k_dram,
                    make_k_dram(nullptr,
                                (kargs.seqlen_k + kargs.seqlen_knew) -
                                    (num_blocks - 1) * kargs.page_block_size));
            }
            else
            {
                return make_page_block_navigator(k_dram);
            }
        }();

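        // Worked example (illustrative): with page_block_size = 128 and
        // seqlen_k + seqlen_knew = 300, num_blocks = ceil(300 / 128) = 3; the first
        // two pages use the full-height k_dram view and the trailing view covers the
        // remaining 300 - 2 * 128 = 44 rows of the last page.
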
        auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
            if constexpr(kIsPagedKV)
            {
                const auto* block_indices =
                    reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
                    i_batch_ * kargs.batch_stride_block_table;
                const index_t num_blocks =
                    integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);

                const long_index_t fixed_offset =
                    static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
                    kargs.nhead_stride_v;

                return make_page_block_navigator<VDataType, 1>(
                    kargs.v_ptr,
                    kargs.batch_stride_v,
                    fixed_offset,
                    block_indices,
                    num_blocks,
                    kargs.page_block_size,
                    v_dram,
                    make_v_dram(nullptr,
                                (kargs.seqlen_k + kargs.seqlen_knew) -
                                    (num_blocks - 1) * kargs.page_block_size));
            }
            else
            {
                return make_page_block_navigator(v_dram);
            }
        }();

        auto q_dram_window =
            make_tile_window(q_dram,
                             make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
                             {i_m0, 0});

        const bool skip_append_kv = kargs.seqlen_knew <= i_n0;
        // window origin = (0, 0) if no work to do for current block
        auto [i_page_block_k, k_dram_window] = k_page_block_navigator.make_tile_window(
            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
            {!skip_append_kv * (kargs.seqlen_k + i_n0), 0});

        auto knew_dram_window =
            make_tile_window(knew_dram,
                             make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
                             {i_n0, 0});

        // window origin = (0, 0) if no work to do for current block
        auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
            make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
            {0, !skip_append_kv * (kargs.seqlen_k + i_n0)});

        auto vnew_dram_window =
            make_tile_window(vnew_dram,
                             make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
                             {0, i_n0});

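        // Tile-level predication example (hypothetical kM0 = kN0 = 64): for decode
        // with seqlen_q = 1 and seqlen_knew = 130, the grid has 3 x-tiles; only tile 0
        // does Q-RoPE work (seqlen_q <= i_m0 holds for tiles 1 and 2), while all
        // three tiles append K/V rows (i_n0 = 0, 64, 128 are all < 130).
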
        if constexpr(kApplyRoPE)
        {
            FmhaPipeline{}(q_dram_window,
                           k_dram_window,
                           i_page_block_k,
                           k_page_block_navigator,
                           knew_dram_window,
                           v_dram_window,
                           i_page_block_v,
                           v_page_block_navigator,
                           vnew_dram_window,
                           q_rotary_cos_dram_window,
                           q_rotary_sin_dram_window,
                           knew_rotary_cos_dram_window,
                           knew_rotary_sin_dram_window,
                           kargs.rotary_dim,
                           kargs.seqlen_q <= i_m0,
                           skip_append_kv);
        }
        else
        {
            FmhaPipeline{}(q_dram_window,
                           k_dram_window,
                           i_page_block_k,
                           k_page_block_navigator,
                           knew_dram_window,
                           v_dram_window,
                           i_page_block_v,
                           v_page_block_navigator,
                           vnew_dram_window,
                           q_rotary_cos_dram_window,
                           q_rotary_sin_dram_window,
                           knew_rotary_cos_dram_window,
                           knew_rotary_sin_dram_window,
                           0, // rotary_dim not used
                           kargs.seqlen_q <= i_m0,
                           skip_append_kv);
        }
    }
};

} // namespace ck_tile