/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp Source File
tile_fmha_shape.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 template <index_t Headdim>
11 static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length()
12 {
13  if constexpr(Headdim == 48)
14  return 48;
15  else if constexpr(Headdim == 80)
16  return 96;
17  else if constexpr(Headdim == 96)
18  return 128;
19  else if constexpr(Headdim == 160)
20  return 256;
21  else if constexpr(Headdim == 192)
22  return 192;
23  else if constexpr(is_power_of_two_integer(Headdim))
24  return Headdim;
25  else
26  static_assert(Headdim == 0,
27  "only Headdim of 48, 96, 160, 192 and power-of-two is supported");
28 };
29 
// Compile-time tile-shape description for an attention kernel: warp counts for two GEMM stages
// plus the tile extents read out of BlockTile.
// NOTE: this listing is an extract. The declaration line (listing line 36) and listing lines
// 38-42, 45, 47 and 67-68 are not present here; the member index after the listing shows the
// aliases they define (BlockTile, Gemm0BlockWarps, Gemm0WarpTile, Gemm1BlockWarps,
// Gemm1WarpTile) and the full VLayout alias.
//
// Template parameters:
//   BlockTile_         sequence of tile sizes; entries 0..5 are read below
//   Gemm0BlockWarps_   warp arrangement of the first GEMM
//   Gemm0WarpTile_     per-warp tile of the first GEMM
//   Gemm1BlockWarps_   warp arrangement of the second GEMM
//   Gemm1WarpTile_     per-warp tile of the second GEMM
//   IsVLayoutRowMajor_ selects the V layout (see IsVLayoutRowMajor / VLayout below)
30 template <typename BlockTile_, // sequence<...
31  typename Gemm0BlockWarps_,
32  typename Gemm0WarpTile_,
33  typename Gemm1BlockWarps_,
34  typename Gemm1WarpTile_,
35  bool IsVLayoutRowMajor_>
37 {
43 
// The initializers of the next two constants (listing lines 45 and 47) are missing from this
// extract.
44  static constexpr index_t NumGemm0Warps =
46  static constexpr index_t NumGemm1Warps =
// The second GEMM's warp count must be a whole multiple of the first GEMM's.
48  static_assert(NumGemm1Warps % NumGemm0Warps == 0);
49 
// Overall warp count is the larger of the two per-GEMM warp counts.
50  static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
51 
52  static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
53  static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
54  static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
55  static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
56  static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
57  static constexpr index_t kQKHeaddim =
58  BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
59  // once (or repeatedly load Q as a whole tile)
60  static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
61 
// kQKHeaddim mapped through ceil_to_qualified_tile_length (defined earlier in this file).
62  static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();
63 
64  // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
65  static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
// The continuation of this alias (listing lines 67-68) is missing from this extract; the member
// index after the listing gives it as RowMajor when IsVLayoutRowMajor is true, ColumnMajor
// otherwise.
66  using VLayout = std::conditional_t<IsVLayoutRowMajor,
69 };
70 
// Compile-time tile-shape description for a five-GEMM attention pipeline. The per-member
// comments below name the stages: gemm0(Q@K^T), gemm1(P^T@dO), gemm2(dO@V^T), gemm3(dS^T@Q)
// and gemm4(dS@K).
// NOTE: this listing is an extract. The declaration line (listing line 83), listing lines 85-95
// and the NumWarps initializer (line 98) are not present here; the member index after the
// listing shows the aliases defined on lines 85-95 (BlockTile and GemmNBlockWarps /
// GemmNWarpTile for N = 0..4).
//
// Template parameters:
//   BlockTile_                          sequence of tile sizes; entries 0..8 are read below
//   GemmNBlockWarps_ / GemmNWarpTile_   warp arrangement and per-warp tile of GEMM N (N = 0..4)
//   kMaxSeqLenQ_                        upper bound on q sequence length; 0 (the default) means
//                                       unlimited, per the static_assert message at the end
71 template <typename BlockTile_, // sequence<...
72  typename Gemm0BlockWarps_,
73  typename Gemm0WarpTile_,
74  typename Gemm1BlockWarps_,
75  typename Gemm1WarpTile_,
76  typename Gemm2BlockWarps_,
77  typename Gemm2WarpTile_,
78  typename Gemm3BlockWarps_,
79  typename Gemm3WarpTile_,
80  typename Gemm4BlockWarps_,
81  typename Gemm4WarpTile_,
82  index_t kMaxSeqLenQ_ = 0>
84 {
96 
// The initializer of NumWarps (listing line 98) is missing from this extract.
97  static constexpr index_t NumWarps =
99 
// NumWarps must match the reduction (multiplies, starting from 1) over both Gemm1BlockWarps and
// Gemm4BlockWarps -- presumably the product of each sequence's entries; confirm against
// reduce_on_sequence.
100  static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) &&
101  NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies{}, number<1>{}));
102 
103  static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
104  static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
105  static constexpr index_t kK0 =
106  BlockTile::at(number<2>{}); // tile size along gemm0(Q@K^T) unroll
107  static constexpr index_t kK1 =
108  BlockTile::at(number<3>{}); // tile size along gemm1(P^T@dO) unroll
109  static constexpr index_t kK2 =
110  BlockTile::at(number<4>{}); // tile size along gemm2(dO@V^T) unroll
111  static constexpr index_t kK3 =
112  BlockTile::at(number<5>{}); // tile size along gemm3(dS^T@Q) unroll
113  static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll
114  static constexpr index_t kQKHeaddim =
115  BlockTile::at(number<7>{}); // Q & K headdim, used for pipeline that need load Q/Q^T or
116  // K/K^T at once
117  static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline
118  // that need load V at once
119 
// Only two settings are accepted: exactly one q tile (kM0), or 0 for "no limit".
120  static constexpr index_t kMaxSeqLenQ = kMaxSeqLenQ_;
121  static_assert(kMaxSeqLenQ == kM0 || kMaxSeqLenQ == 0,
122  "kMaxSeqLenQ should be equal to kM0 or 0, if 0, it means seq len Q is unlimited");
123 };
124 
125 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE bool is_power_of_two_integer(int32_t x)
Definition: math.hpp:455
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:993
CK_TILE_HOST_DEVICE_EXTERN multiplies() -> multiplies< void, void >
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:157
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: tile_fmha_shape.hpp:84
static constexpr index_t kQKHeaddim
Definition: tile_fmha_shape.hpp:114
remove_cvref_t< Gemm0BlockWarps_ > Gemm0BlockWarps
Definition: tile_fmha_shape.hpp:86
static constexpr index_t kK3
Definition: tile_fmha_shape.hpp:111
remove_cvref_t< Gemm1WarpTile_ > Gemm1WarpTile
Definition: tile_fmha_shape.hpp:89
static constexpr index_t kN0
Definition: tile_fmha_shape.hpp:104
static constexpr index_t kMaxSeqLenQ
Definition: tile_fmha_shape.hpp:120
remove_cvref_t< Gemm4WarpTile_ > Gemm4WarpTile
Definition: tile_fmha_shape.hpp:95
remove_cvref_t< Gemm4BlockWarps_ > Gemm4BlockWarps
Definition: tile_fmha_shape.hpp:94
static constexpr index_t kVHeaddim
Definition: tile_fmha_shape.hpp:117
remove_cvref_t< Gemm2BlockWarps_ > Gemm2BlockWarps
Definition: tile_fmha_shape.hpp:90
remove_cvref_t< Gemm3BlockWarps_ > Gemm3BlockWarps
Definition: tile_fmha_shape.hpp:92
remove_cvref_t< Gemm0WarpTile_ > Gemm0WarpTile
Definition: tile_fmha_shape.hpp:87
remove_cvref_t< Gemm2WarpTile_ > Gemm2WarpTile
Definition: tile_fmha_shape.hpp:91
static constexpr index_t kM0
Definition: tile_fmha_shape.hpp:103
remove_cvref_t< BlockTile_ > BlockTile
Definition: tile_fmha_shape.hpp:85
static constexpr index_t kK4
Definition: tile_fmha_shape.hpp:113
static constexpr index_t kK0
Definition: tile_fmha_shape.hpp:105
static constexpr index_t kK2
Definition: tile_fmha_shape.hpp:109
static constexpr index_t NumWarps
Definition: tile_fmha_shape.hpp:97
static constexpr index_t kK1
Definition: tile_fmha_shape.hpp:107
remove_cvref_t< Gemm3WarpTile_ > Gemm3WarpTile
Definition: tile_fmha_shape.hpp:93
remove_cvref_t< Gemm1BlockWarps_ > Gemm1BlockWarps
Definition: tile_fmha_shape.hpp:88
Definition: tile_fmha_shape.hpp:37
std::conditional_t< IsVLayoutRowMajor, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: tile_fmha_shape.hpp:68
remove_cvref_t< Gemm1BlockWarps_ > Gemm1BlockWarps
Definition: tile_fmha_shape.hpp:41
remove_cvref_t< Gemm1WarpTile_ > Gemm1WarpTile
Definition: tile_fmha_shape.hpp:42
remove_cvref_t< Gemm0WarpTile_ > Gemm0WarpTile
Definition: tile_fmha_shape.hpp:40
remove_cvref_t< Gemm0BlockWarps_ > Gemm0BlockWarps
Definition: tile_fmha_shape.hpp:39
remove_cvref_t< BlockTile_ > BlockTile
Definition: tile_fmha_shape.hpp:38
Definition: integral_constant.hpp:13
Definition: math.hpp:95
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17