/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp Source File
tile_fmha_shape.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
11 {
12  if(len == 96)
13  return 128;
14  if(len == 160)
15  return 256;
16  if(len == 192)
17  return 192;
18 
19  // only length of 96, 160 and power-of-two is supported
20  if(!(len & (len - 1)))
21  return len;
22 
23  return 0;
24 };
25 
// Compile-time tile-shape description for an FMHA pipeline with two gemms.
//
// Template parameters:
//   BlockTile_         - sequence<...> of block tile lengths. Elements 0..5 are read below
//                        as kM0, kN0, kK0, kN1, kK1 and kQKHeaddim.
//   Gemm0BlockWarps_   - warp layout used for gemm0.
//   Gemm0WarpTile_     - per-warp tile used for gemm0.
//   Gemm1BlockWarps_   - warp layout used for gemm1.
//   Gemm1WarpTile_     - per-warp tile used for gemm1.
//   IsVLayoutRowMajor_ - V layout: row-major is seqlen*hdim, column-major is hdim*seqlen.
//
// Compile-time checks visible here:
//   - NumGemm1Warps must be a whole multiple of NumGemm0Warps.
//   - kQKHeaddim must be divisible by kK0.
//
// NOTE(review): this rendered listing is missing source lines 32, 34-38, 41, 43, 46 and
// 62-64. Per the cross-reference entries after the listing, those lines hold:
//   - the struct name;
//   - the remove_cvref_t aliases for the template parameters;
//   - the initializers of NumGemm0Warps and NumGemm1Warps;
//   - NumWarps;
//   - the VLayout alias.
26 template <typename BlockTile_, // sequence<...
27  typename Gemm0BlockWarps_,
28  typename Gemm0WarpTile_,
29  typename Gemm1BlockWarps_,
30  typename Gemm1WarpTile_,
31  bool IsVLayoutRowMajor_>
33 {
39 
40  static constexpr index_t NumGemm0Warps =
42  static constexpr index_t NumGemm1Warps =
44  static_assert(NumGemm1Warps % NumGemm0Warps == 0);
45 
47 
48  static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
49  static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
50  static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
51  static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
52  static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
53  static constexpr index_t kQKHeaddim =
54  BlockTile::at(number<5>{}); // total length of K0, used by pipelines that need to load Q at
55  // once (or repeatedly load Q as a whole tile)
56  static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
57 
// Head dim rounded to a supported tile length (96->128, 160->256, 192->192, powers of two
// unchanged). NOTE(review): ceil_to_qualified_tile_length returns 0 for any other length,
// and nothing visible here rejects that. Consider a static_assert(kSubQKHeaddim != 0)
// unless 0 is relied on -- TODO confirm.
58  static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
59 
60  // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
61  static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
65 };
66 
// Compile-time tile-shape description for a five-gemm FMHA pipeline.
// It appears to be the backward pass, since the gemms involve dO and dS -- TODO confirm.
//
// The gemms, per the member comments below:
//   gemm0 : Q@K^T
//   gemm1 : P^T@dO
//   gemm2 : dO@V^T
//   gemm3 : dS^T@Q
//   gemm4 : dS@K
//
// BlockTile_ elements 0..8 are read as kM0, kN0, kK0, kK1, kK2, kK3, kK4, kQKHeaddim and
// kVHeaddim. Each GemmNBlockWarps_/GemmNWarpTile_ pair gives the warp layout and per-warp
// tile for gemm N.
//
// kMaxSeqLenQ_ must be 0 or equal to kM0 (enforced by static_assert). Per the assert
// message, 0 means the Q sequence length is unlimited.
//
// NOTE(review): this rendered listing is missing source lines 79, 81-91, 94 and 97. Those
// lines hold the struct name, the remove_cvref_t aliases, the NumWarps initializer and
// the rest of the warp-count static_assert.
67 template <typename BlockTile_, // sequence<...
68  typename Gemm0BlockWarps_,
69  typename Gemm0WarpTile_,
70  typename Gemm1BlockWarps_,
71  typename Gemm1WarpTile_,
72  typename Gemm2BlockWarps_,
73  typename Gemm2WarpTile_,
74  typename Gemm3BlockWarps_,
75  typename Gemm3WarpTile_,
76  typename Gemm4BlockWarps_,
77  typename Gemm4WarpTile_,
78  index_t kMaxSeqLenQ_ = 0>
80 {
92 
93  static constexpr index_t NumWarps =
95 
96  static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) &&
98 
99  static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
100  static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
101  static constexpr index_t kK0 =
102  BlockTile::at(number<2>{}); // tile size along gemm0(Q@K^T) unroll
103  static constexpr index_t kK1 =
104  BlockTile::at(number<3>{}); // tile size along gemm1(P^T@dO) unroll
105  static constexpr index_t kK2 =
106  BlockTile::at(number<4>{}); // tile size along gemm2(dO@V^T) unroll
107  static constexpr index_t kK3 =
108  BlockTile::at(number<5>{}); // tile size along gemm3(dS^T@Q) unroll
109  static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll
110  static constexpr index_t kQKHeaddim =
111  BlockTile::at(number<7>{}); // Q & K headdim, used by pipelines that need to load Q/Q^T or
112  // K/K^T at once
113  static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used by pipelines
114  // that need to load V at once
115 
116  static constexpr index_t kMaxSeqLenQ = kMaxSeqLenQ_;
117  static_assert(kMaxSeqLenQ == kM0 || kMaxSeqLenQ == 0,
118  "kMaxSeqLenQ should be equal to kM0 or 0, if 0, it means seq len Q is unlimited");
119 };
120 
121 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:979
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: tile_fmha_shape.hpp:80
static constexpr index_t kQKHeaddim
Definition: tile_fmha_shape.hpp:110
remove_cvref_t< Gemm0BlockWarps_ > Gemm0BlockWarps
Definition: tile_fmha_shape.hpp:82
static constexpr index_t kK3
Definition: tile_fmha_shape.hpp:107
remove_cvref_t< Gemm1WarpTile_ > Gemm1WarpTile
Definition: tile_fmha_shape.hpp:85
static constexpr index_t kN0
Definition: tile_fmha_shape.hpp:100
static constexpr index_t kMaxSeqLenQ
Definition: tile_fmha_shape.hpp:116
remove_cvref_t< Gemm4WarpTile_ > Gemm4WarpTile
Definition: tile_fmha_shape.hpp:91
remove_cvref_t< Gemm4BlockWarps_ > Gemm4BlockWarps
Definition: tile_fmha_shape.hpp:90
static constexpr index_t kVHeaddim
Definition: tile_fmha_shape.hpp:113
remove_cvref_t< Gemm2BlockWarps_ > Gemm2BlockWarps
Definition: tile_fmha_shape.hpp:86
remove_cvref_t< Gemm3BlockWarps_ > Gemm3BlockWarps
Definition: tile_fmha_shape.hpp:88
remove_cvref_t< Gemm0WarpTile_ > Gemm0WarpTile
Definition: tile_fmha_shape.hpp:83
remove_cvref_t< Gemm2WarpTile_ > Gemm2WarpTile
Definition: tile_fmha_shape.hpp:87
static constexpr index_t kM0
Definition: tile_fmha_shape.hpp:99
remove_cvref_t< BlockTile_ > BlockTile
Definition: tile_fmha_shape.hpp:81
static constexpr index_t kK4
Definition: tile_fmha_shape.hpp:109
static constexpr index_t kK0
Definition: tile_fmha_shape.hpp:101
static constexpr index_t kK2
Definition: tile_fmha_shape.hpp:105
static constexpr index_t NumWarps
Definition: tile_fmha_shape.hpp:93
static constexpr index_t kK1
Definition: tile_fmha_shape.hpp:103
remove_cvref_t< Gemm3WarpTile_ > Gemm3WarpTile
Definition: tile_fmha_shape.hpp:89
remove_cvref_t< Gemm1BlockWarps_ > Gemm1BlockWarps
Definition: tile_fmha_shape.hpp:84
Definition: tile_fmha_shape.hpp:33
static constexpr bool IsVLayoutRowMajor
Definition: tile_fmha_shape.hpp:61
std::conditional_t< IsVLayoutRowMajor, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: tile_fmha_shape.hpp:64
static constexpr index_t kQKHeaddim
Definition: tile_fmha_shape.hpp:53
static constexpr index_t kK0
Definition: tile_fmha_shape.hpp:50
remove_cvref_t< Gemm1BlockWarps_ > Gemm1BlockWarps
Definition: tile_fmha_shape.hpp:37
static constexpr index_t NumGemm0Warps
Definition: tile_fmha_shape.hpp:40
static constexpr index_t NumWarps
Definition: tile_fmha_shape.hpp:46
static constexpr index_t kK1
Definition: tile_fmha_shape.hpp:52
remove_cvref_t< Gemm1WarpTile_ > Gemm1WarpTile
Definition: tile_fmha_shape.hpp:38
remove_cvref_t< Gemm0WarpTile_ > Gemm0WarpTile
Definition: tile_fmha_shape.hpp:36
static constexpr index_t kSubQKHeaddim
Definition: tile_fmha_shape.hpp:58
static constexpr index_t kN0
Definition: tile_fmha_shape.hpp:49
remove_cvref_t< Gemm0BlockWarps_ > Gemm0BlockWarps
Definition: tile_fmha_shape.hpp:35
static constexpr index_t kN1
Definition: tile_fmha_shape.hpp:51
static constexpr index_t kM0
Definition: tile_fmha_shape.hpp:48
remove_cvref_t< BlockTile_ > BlockTile
Definition: tile_fmha_shape.hpp:34
static constexpr index_t NumGemm1Warps
Definition: tile_fmha_shape.hpp:42
Definition: integral_constant.hpp:13
Definition: math.hpp:98
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17