/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp Source File
fmha_fwd_appendkv_tile_partitioner.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 template <index_t kM0_, index_t kN0_, index_t kK0_, index_t kN1_>
11 struct FmhaFwdAppendKVTilePartitioner
12 {
13  static constexpr ck_tile::index_t kM0 = kM0_;
14  static constexpr ck_tile::index_t kN0 = kN0_;
15  static constexpr ck_tile::index_t kK0 = kK0_;
16  static constexpr ck_tile::index_t kN1 = kN1_;
17 
18  static_assert(kK0 == kN1);
19 
// Host-side helper that builds the launch grid for this partitioner.
//
// The returned dim3 is filled as:
//   x = max(integer_divide_ceil(seqlen_q, kM0), integer_divide_ceil(seqlen_knew, kN0))
//   y = nhead
//   z = batch_size
// This matches operator() in this struct, which reads blockIdx.x / .y / .z back
// as (i_tile, i_nhead, i_batch).
//
// x takes the larger of the two tile counts, so a single tile index range spans
// both the kM0-sized split of seqlen_q and the kN0-sized split of seqlen_knew.
// Presumably blocks whose tile index exceeds the smaller count skip that part
// of the work -- confirm against the kernel that consumes this partitioner.
//
// NOTE(review): std::max is used here, but the only include visible in this
// file is "ck_tile/core.hpp"; <algorithm> is presumably pulled in transitively
// -- confirm.
20  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
21  ck_tile::index_t nhead,
22  ck_tile::index_t seqlen_q,
23  ck_tile::index_t seqlen_knew)
24  {
25  // TODO: this may need tuning
26  return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0),
27  ck_tile::integer_divide_ceil(seqlen_knew, kN0)),
28  nhead,
29  batch_size);
30  }
31 
32  CK_TILE_DEVICE auto operator()()
33  {
34  const index_t i_tile = blockIdx.x;
35  const index_t i_nhead = blockIdx.y;
36  const index_t i_batch = blockIdx.z;
37 
38  return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
39  }
40 };
41 
42 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
Definition: fmha_fwd_appendkv_tile_partitioner.hpp:12
static constexpr ck_tile::index_t kN1
Definition: fmha_fwd_appendkv_tile_partitioner.hpp:16
CK_TILE_DEVICE auto operator()()
Definition: fmha_fwd_appendkv_tile_partitioner.hpp:32
static constexpr ck_tile::index_t kK0
Definition: fmha_fwd_appendkv_tile_partitioner.hpp:15
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_knew)
Definition: fmha_fwd_appendkv_tile_partitioner.hpp:20
static constexpr ck_tile::index_t kM0
Definition: fmha_fwd_appendkv_tile_partitioner.hpp:13
static constexpr ck_tile::index_t kN0
Definition: fmha_fwd_appendkv_tile_partitioner.hpp:14