/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp Source File
fused_moegemm_shape.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 /*
11 tensors:
12 1. act (A): input feature map
13 2. gate (G): B matrix for first gemm, its output goes through the activation (e.g. Silu)
14 3. up (U): B matrix for first gemm
15 4. down (D): B matrix for second gemm
16  N1
17  / \
18  +----------+ |
19  | Down | |
20  x----------x |
21  hidden hidden K1 | | |
22  N0 N0 x----------x |
23  | +------x-----x------+------x-----x------+ | | |
24  dim | | Gate | | | Up | | | | | |
25  contiguous | | | | | | | | | | |
26  | | | | | | | | | | |
27  v +------x-----x------+------x-----x------+ +----------+ V
28  K0 | | | | | contiguous
29  / \ v v v v |
30  +---------+ +------x-----x------+------x-----x------+ |
31 M0 | A | | | | | | | | |
32  +---------+ +------x-----x------+------x-----x------+ |
33  ----------> | | |
34  contiguous | V V
35  | x-----x +----------+
36  +------------> M1 | Y | ---------> | Out(O) |
37  ACT x-----x +----------+
38  K1 = N0 dim
39 
40 * Note: Act could be Gelu/Silu/...
41 * Note: some models do not have Up
42 */
43 template <typename BlockTile_0_,
44  typename WarpPerBlock_0_,
45  typename WarpTile_0_,
46  typename BlockTile_1_,
47  typename WarpPerBlock_1_,
48  typename WarpTile_1_>
50 {
57 
58  static constexpr index_t NumWarps =
59  reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
60 
61  // TODO: we don't support half warps rounded to 1 warp here
62  static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
63 
64  static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
65  static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});
66  static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{});
67  static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{});
68  static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{});
69  static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{});
70  static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{});
71  static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{});
72  static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{});
73 
77  static_assert(Block_M0 % ThreadPerBlock_M0 == 0);
78  static_assert(Block_N0 % ThreadPerBlock_N0 == 0);
79  static_assert(Block_K0 % ThreadPerBlock_K0 == 0);
83 
84  static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{});
85  static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{});
86  static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{});
87  static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{});
88  static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{});
89  static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{});
90  static constexpr index_t Warp_M1 = WarpTile_1::at(number<0>{});
91  static constexpr index_t Warp_N1 = WarpTile_1::at(number<1>{});
92  static constexpr index_t Warp_K1 = WarpTile_1::at(number<2>{});
93 
97  static_assert(Block_M1 % ThreadPerBlock_M1 == 0);
98  static_assert(Block_N1 % ThreadPerBlock_N1 == 0);
99  static_assert(Block_K1 % ThreadPerBlock_K1 == 0);
103 
104  static constexpr index_t BlockSize = get_warp_size() * NumWarps;
105 
106  // some assert
107  static_assert(Block_M0 == Block_M1);
108  static_assert(Block_N0 == Block_K1 || (Block_N0 / 2) == Block_K1); // Gate Only or Gate+Up
109 
110  // pre-shuffle tile size compute (assume only for B matrix)
111  // we flatten each wave tile to a 1d linear tensor (at model loading time)
112  // e.g. originally we have Block_N*Block_K tile size, after pre-shuffle
113  // we can have Block_Nr*Block_Kr*Block_W, where Block_W is Warp_N*Warp_K,
114  // and Block_Nr=Block_N/Warp_N, Block_Kr=Block_K/Warp_K
115  static constexpr index_t Block_W0 = Warp_N0 * Warp_K0;
116  static constexpr index_t Block_Nr0 = Block_N0 / Warp_N0;
117  static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0;
118  static constexpr index_t Block_W1 = Warp_N1 * Warp_K1;
119  static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1;
120  static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1;
121 
122  static_assert(Block_W0 == Block_W1);
123  // static_assert(Block_Nr0 == Block_Kr1);
124 };
125 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:979
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
Definition: fused_moegemm_shape.hpp:50
static constexpr index_t ThreadPerBlock_N0
Definition: fused_moegemm_shape.hpp:75
static constexpr index_t Repeat_K0
Definition: fused_moegemm_shape.hpp:82
static constexpr index_t ThreadPerBlock_M1
Definition: fused_moegemm_shape.hpp:94
static constexpr index_t BlockSize
Definition: fused_moegemm_shape.hpp:104
static constexpr index_t Repeat_K1
Definition: fused_moegemm_shape.hpp:102
remove_cvref_t< WarpTile_0_ > WarpTile_0
Definition: fused_moegemm_shape.hpp:53
static constexpr index_t Block_N0
Definition: fused_moegemm_shape.hpp:65
static constexpr index_t Warp_N0
Definition: fused_moegemm_shape.hpp:71
static constexpr index_t Block_K0
Definition: fused_moegemm_shape.hpp:66
remove_cvref_t< WarpPerBlock_1_ > WarpPerBlock_1
Definition: fused_moegemm_shape.hpp:55
static constexpr index_t Warp_N1
Definition: fused_moegemm_shape.hpp:91
static constexpr index_t Block_W0
Definition: fused_moegemm_shape.hpp:115
static constexpr index_t WarpPerBlock_N0
Definition: fused_moegemm_shape.hpp:68
static constexpr index_t WarpPerBlock_M0
Definition: fused_moegemm_shape.hpp:67
static constexpr index_t Repeat_N0
Definition: fused_moegemm_shape.hpp:81
static constexpr index_t Block_Nr1
Definition: fused_moegemm_shape.hpp:119
static constexpr index_t Block_Kr1
Definition: fused_moegemm_shape.hpp:120
static constexpr index_t Repeat_M1
Definition: fused_moegemm_shape.hpp:100
static constexpr index_t NumWarps
Definition: fused_moegemm_shape.hpp:58
static constexpr index_t Repeat_M0
Definition: fused_moegemm_shape.hpp:80
remove_cvref_t< WarpTile_1_ > WarpTile_1
Definition: fused_moegemm_shape.hpp:56
static constexpr index_t Block_M1
Definition: fused_moegemm_shape.hpp:84
static constexpr index_t ThreadPerBlock_M0
Definition: fused_moegemm_shape.hpp:74
static constexpr index_t Block_K1
Definition: fused_moegemm_shape.hpp:86
static constexpr index_t Block_M0
Definition: fused_moegemm_shape.hpp:64
remove_cvref_t< BlockTile_0_ > BlockTile_0
Definition: fused_moegemm_shape.hpp:51
static constexpr index_t WarpPerBlock_K1
Definition: fused_moegemm_shape.hpp:89
remove_cvref_t< WarpPerBlock_0_ > WarpPerBlock_0
Definition: fused_moegemm_shape.hpp:52
static constexpr index_t Block_Nr0
Definition: fused_moegemm_shape.hpp:116
static constexpr index_t Warp_M0
Definition: fused_moegemm_shape.hpp:70
static constexpr index_t Block_Kr0
Definition: fused_moegemm_shape.hpp:117
remove_cvref_t< BlockTile_1_ > BlockTile_1
Definition: fused_moegemm_shape.hpp:54
static constexpr index_t WarpPerBlock_M1
Definition: fused_moegemm_shape.hpp:87
static constexpr index_t Warp_M1
Definition: fused_moegemm_shape.hpp:90
static constexpr index_t ThreadPerBlock_K1
Definition: fused_moegemm_shape.hpp:96
static constexpr index_t Block_N1
Definition: fused_moegemm_shape.hpp:85
static constexpr index_t Repeat_N1
Definition: fused_moegemm_shape.hpp:101
static constexpr index_t WarpPerBlock_K0
Definition: fused_moegemm_shape.hpp:69
static constexpr index_t Warp_K0
Definition: fused_moegemm_shape.hpp:72
static constexpr index_t ThreadPerBlock_K0
Definition: fused_moegemm_shape.hpp:76
static constexpr index_t WarpPerBlock_N1
Definition: fused_moegemm_shape.hpp:88
static constexpr index_t ThreadPerBlock_N1
Definition: fused_moegemm_shape.hpp:95
static constexpr index_t Warp_K1
Definition: fused_moegemm_shape.hpp:92
static constexpr index_t Block_W1
Definition: fused_moegemm_shape.hpp:118
Definition: integral_constant.hpp:13
Definition: math.hpp:98