/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp Source File
tile_fmha_shape.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 template <index_t Headdim>
11 static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length()
12 {
13  if constexpr(Headdim == 48)
14  return 48;
15  else if constexpr(Headdim == 96)
16  return 128;
17  else if constexpr(Headdim == 160)
18  return 256;
19  else if constexpr(Headdim == 192)
20  return 192;
21  else if constexpr(is_power_of_two_integer(Headdim))
22  return Headdim;
23  else
24  static_assert(Headdim == 0,
25  "only Headdim of 48, 96, 160, 192 and power-of-two is supported");
26 };
27 
// Compile-time shape description: exposes the block tile lengths read from BlockTile,
// the warp counts of the two gemms, and the layout selected for V.
// NOTE(review): this extracted listing skips original source lines 34, 36-40, 43, 45 and
// 65-66 (the struct header, the type aliases, and the right-hand sides of NumGemm0Warps,
// NumGemm1Warps and VLayout) - consult the real header before editing this block.
28 template <typename BlockTile_, // sequence<...
29  typename Gemm0BlockWarps_,
30  typename Gemm0WarpTile_,
31  typename Gemm1BlockWarps_,
32  typename Gemm1WarpTile_,
33  bool IsVLayoutRowMajor_>
35 {
41 
42  static constexpr index_t NumGemm0Warps =
44  static constexpr index_t NumGemm1Warps =
 // NumGemm1Warps has to be a whole multiple of NumGemm0Warps.
46  static_assert(NumGemm1Warps % NumGemm0Warps == 0);
47 
 // Overall warp count: the larger of the two per-gemm warp counts.
48  static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
49 
 // Tile lengths, taken from positions 0..5 of BlockTile.
50  static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
51  static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
52  static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
53  static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
54  static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
55  static constexpr index_t kQKHeaddim =
56  BlockTile::at(number<5>{}); // total length of K0, used for pipelines that need to load Q
57  // at once (or repeatedly load Q as a whole tile)
58  static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
59 
 // kQKHeaddim mapped through ceil_to_qualified_tile_length(): 96 becomes 128, 160 becomes
 // 256, the other supported sizes are kept as they are.
60  static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();
61 
62  // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
63  static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
 // Layout tag chosen by IsVLayoutRowMajor; the alternatives of this conditional_t are on
 // lines that this listing does not show.
64  using VLayout = std::conditional_t<IsVLayoutRowMajor,
67 };
68 
// Compile-time shape description for a pipeline made of five gemms (gemm0..gemm4, as named
// in the member comments below): exposes the tile lengths read from BlockTile, the warp
// count, and an optional upper bound on the Q sequence length.
// NOTE(review): this extracted listing skips original source lines 81, 83-93, 96 and 99
// (the struct header, the type aliases, the right-hand side of NumWarps and the tail of the
// NumWarps static_assert) - consult the real header before editing this block.
69 template <typename BlockTile_, // sequence<...
70  typename Gemm0BlockWarps_,
71  typename Gemm0WarpTile_,
72  typename Gemm1BlockWarps_,
73  typename Gemm1WarpTile_,
74  typename Gemm2BlockWarps_,
75  typename Gemm2WarpTile_,
76  typename Gemm3BlockWarps_,
77  typename Gemm3WarpTile_,
78  typename Gemm4BlockWarps_,
79  typename Gemm4WarpTile_,
80  index_t kMaxSeqLenQ_ = 0> // 0 means the Q sequence length is unlimited
82 {
94 
95  static constexpr index_t NumWarps =
97 
 // NumWarps must match the multiplies-reduction (initial value 1) over Gemm1BlockWarps,
 // presumably the product of its entries; the remaining conjuncts of this assertion are
 // not shown in this listing.
98  static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) &&
100 
 // Tile lengths, taken from positions 0..8 of BlockTile.
101  static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
102  static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
103  static constexpr index_t kK0 =
104  BlockTile::at(number<2>{}); // tile size along gemm0(Q@K^T) unroll
105  static constexpr index_t kK1 =
106  BlockTile::at(number<3>{}); // tile size along gemm1(P^T@dO) unroll
107  static constexpr index_t kK2 =
108  BlockTile::at(number<4>{}); // tile size along gemm2(dO@V^T) unroll
109  static constexpr index_t kK3 =
110  BlockTile::at(number<5>{}); // tile size along gemm3(dS^T@Q) unroll
111  static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll
112  static constexpr index_t kQKHeaddim =
113  BlockTile::at(number<7>{}); // Q & K headdim, used for pipelines that need to load Q/Q^T
114  // or K/K^T at once
115  static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipelines
116  // that need to load V at once
117 
 // Optional bound on the Q sequence length: either 0 (no bound) or exactly kM0.
118  static constexpr index_t kMaxSeqLenQ = kMaxSeqLenQ_;
119  static_assert(kMaxSeqLenQ == kM0 || kMaxSeqLenQ == 0,
120  "kMaxSeqLenQ should be equal to kM0 or 0, if 0, it means seq len Q is unlimited");
121 };
122 
123 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
constexpr CK_TILE_HOST_DEVICE bool is_power_of_two_integer(int32_t x)
Definition: math.hpp:462
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:982
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: tile_fmha_shape.hpp:82
static constexpr index_t kQKHeaddim
Definition: tile_fmha_shape.hpp:112
remove_cvref_t< Gemm0BlockWarps_ > Gemm0BlockWarps
Definition: tile_fmha_shape.hpp:84
static constexpr index_t kK3
Definition: tile_fmha_shape.hpp:109
remove_cvref_t< Gemm1WarpTile_ > Gemm1WarpTile
Definition: tile_fmha_shape.hpp:87
static constexpr index_t kN0
Definition: tile_fmha_shape.hpp:102
static constexpr index_t kMaxSeqLenQ
Definition: tile_fmha_shape.hpp:118
remove_cvref_t< Gemm4WarpTile_ > Gemm4WarpTile
Definition: tile_fmha_shape.hpp:93
remove_cvref_t< Gemm4BlockWarps_ > Gemm4BlockWarps
Definition: tile_fmha_shape.hpp:92
static constexpr index_t kVHeaddim
Definition: tile_fmha_shape.hpp:115
remove_cvref_t< Gemm2BlockWarps_ > Gemm2BlockWarps
Definition: tile_fmha_shape.hpp:88
remove_cvref_t< Gemm3BlockWarps_ > Gemm3BlockWarps
Definition: tile_fmha_shape.hpp:90
remove_cvref_t< Gemm0WarpTile_ > Gemm0WarpTile
Definition: tile_fmha_shape.hpp:85
remove_cvref_t< Gemm2WarpTile_ > Gemm2WarpTile
Definition: tile_fmha_shape.hpp:89
static constexpr index_t kM0
Definition: tile_fmha_shape.hpp:101
remove_cvref_t< BlockTile_ > BlockTile
Definition: tile_fmha_shape.hpp:83
static constexpr index_t kK4
Definition: tile_fmha_shape.hpp:111
static constexpr index_t kK0
Definition: tile_fmha_shape.hpp:103
static constexpr index_t kK2
Definition: tile_fmha_shape.hpp:107
static constexpr index_t NumWarps
Definition: tile_fmha_shape.hpp:95
static constexpr index_t kK1
Definition: tile_fmha_shape.hpp:105
remove_cvref_t< Gemm3WarpTile_ > Gemm3WarpTile
Definition: tile_fmha_shape.hpp:91
remove_cvref_t< Gemm1BlockWarps_ > Gemm1BlockWarps
Definition: tile_fmha_shape.hpp:86
Definition: tile_fmha_shape.hpp:35
std::conditional_t< IsVLayoutRowMajor, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: tile_fmha_shape.hpp:66
remove_cvref_t< Gemm1BlockWarps_ > Gemm1BlockWarps
Definition: tile_fmha_shape.hpp:39
remove_cvref_t< Gemm1WarpTile_ > Gemm1WarpTile
Definition: tile_fmha_shape.hpp:40
remove_cvref_t< Gemm0WarpTile_ > Gemm0WarpTile
Definition: tile_fmha_shape.hpp:38
remove_cvref_t< Gemm0BlockWarps_ > Gemm0BlockWarps
Definition: tile_fmha_shape.hpp:37
remove_cvref_t< BlockTile_ > BlockTile
Definition: tile_fmha_shape.hpp:36
Definition: integral_constant.hpp:13
Definition: math.hpp:98
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17