Composable Kernel: include/ck_tile/host/reference/reference_fused_moe.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {
// [indexing implementation-1]
// use M_a as a constexpr block_size to partition all tokens into slices;
// each slice maps to one expert, and one expert can own multiple slices
// e.g. num_experts = 6, topk = 3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                             tok-0      tok-1      tok-2      tok-3      tok-4
// topk_weight is           : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float numbers)
//
// token_id_per_expert is  : [[0],   [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [],    [0, 1, 2, 5]]
// (only for reference)      exp-0   exp-1      exp-2   exp-3            exp-4  exp-5
// weight_id_per_expert is : [[a],   [g, j, m], [d, k], [b, e, h, l, n], [],    [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk
// * this can be larger than the real padded length, since the actual per-expert
//   token counts are only known on the GPU
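// e.g. for the example above (input_tokens = 5, topk = 3, num_experts = 6, M_a = 4):
//   max_num_tokens_padded = 3 * 5 + 6 * 4 - 3 = 36 (upper bound),
//   while the actually padded length here is 28 (see num_tokens_post_padded_ptr below)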
//
// sorted_token_ids_ptr  : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
//                          0, 1, 2, 5]
//                         |- exp-0  -|- exp-1  -|- exp-2  -|-       exp-3      -|- exp-4  -|
//                         |- exp-5  -|
// sorted_weight_ptr     : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
//                          c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr       : [7]

template <typename AccDataType, // you only need to explicitly set this one
          typename Activation,  // ck_tile::element_wise::Gelu
          typename ADataType,
          typename GDataType,
          typename DDataType,
          typename ODataType,
          typename AScaleDataType,
          typename GScaleDataType,
          typename DScaleDataType,
          typename YSmoothScaleDataType,
          typename TopkWeightDataType,
          typename IndexDataType>
void reference_fused_moe(
    const ck_tile::HostTensor<ADataType>& a_host,       // [tokens, hidden_size]
    const ck_tile::HostTensor<GDataType>& g_host,       // [experts, interme_size_0, hidden_size]
    const ck_tile::HostTensor<DDataType>& d_host,       // [experts, hidden_size, interme_size_1]
    const ck_tile::HostTensor<AScaleDataType>& sa_host, // [tokens, 1]
    const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size_0]
    const ck_tile::HostTensor<DScaleDataType>& sd_host, // [experts, 1, hidden_size]
    const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
    ck_tile::HostTensor<ODataType>& o_host,                   // [tokens, hidden_size]
    const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host,   // [max_num_tokens_padded]
    const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
    const ck_tile::HostTensor<IndexDataType>&
        sorted_expert_ids_host, // [(max_num_tokens_padded + block_size - 1) / block_size]
    const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]

    const ck_tile::HostTensor<IndexDataType>&
        token_ids_host, // [tokens, topk] --> ugly!!! remove in the future

    ck_tile::index_t block_m,
    ck_tile::index_t tokens,
    ck_tile::index_t experts,
    ck_tile::index_t hidden_size,
    ck_tile::index_t intermediate_size, // this size is for gate/up/down
    ck_tile::index_t topk,
    ck_tile::index_t gate_only)
{
    assert(sorted_token_ids_host.get_num_of_dimension() == 1);
    assert(sorted_weight_host.get_num_of_dimension() == 1);
    assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
    assert(num_sorted_tiles_host.get_element_size() == 1);
    ck_tile::index_t num_sorted_tiles    = num_sorted_tiles_host.mData[0] / block_m;
    ck_tile::index_t intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2);
    ck_tile::index_t intermediate_size_1 = intermediate_size;

    ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});

    int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
    // assert();
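    // f handles one slot of the flattened, padded token array: it resolves the owning tile,
    // expert id, original token id and topk position, then runs the first gemm + activation +
    // the second gemm for that (token, expert) pair and stores the weighted partial result
    // per (token, topk).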
    auto f = [&](auto i_flatten) {
        ck_tile::index_t i_tile = i_flatten / block_m;
        if(i_tile >= num_sorted_tiles)
            return;
        ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];

#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
        ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
        ck_tile::index_t i_topk  = i_token >> 24;
        i_token &= 0xffffff;
        if(i_token >= tokens)
            return;
        (void)token_ids_host;
#else
        // TODO: better to remove this in the future, or modify the token_id value
        auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
            for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
            {
                if(token_ids_host(token_id_, i_) == expert_id_)
                    return i_;
            }
            throw std::runtime_error("no matching token/expert pair\n");
            return -1; // unreachable; keeps every control path returning a value
        };
        ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
        if(i_token >= tokens)
            return;
        ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
#endif
        auto weight = sorted_weight_host.mData[i_flatten];

        ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
        // first gemm: a_host[i_token, :] (1 x hidden_size) * g_host[i_expert]^T
        //             -> acc_0 (1 x intermediate_size_0)
        for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
        {
            AccDataType acc = static_cast<AccDataType>(0);
            for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
            {
                acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
                       type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
            }
            acc_0(0, i_n) = acc;
            // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
        }

        ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
        if(gate_only)
        {
            if(intermediate_size_1 != intermediate_size_0)
                throw std::runtime_error(
                    "intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
                    ", 1:" + std::to_string(intermediate_size_1));
            for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
            {
                Activation{}(y(0, i_n), acc_0(0, i_n));
                // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
            }
        }
        else
        {
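            // gate + up layout: the first intermediate_size_1 columns of acc_0 are the gate
            // projection, the next intermediate_size_1 columns are the up projection;
            // y = Activation(gate) * up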
            if(intermediate_size_1 * 2 != intermediate_size_0)
                throw std::runtime_error(
                    "intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
                    ", 1:" + std::to_string(intermediate_size_1));
            for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
            {
                AccDataType tmp;
                Activation{}(tmp, acc_0(0, i_n));
                y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
            }
        }

        // second gemm, loop along gemm-n
        ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            AccDataType acc = static_cast<AccDataType>(0);
            for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
            {
                acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
            }
            acc_1(0, i_n) = acc * weight; // multiply by the topk weight here
        }

        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
        }
    };

    // make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
    make_ParallelTensorFunctor(f, max_num_tokens_padded)(1);

    // reduce
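    // r accumulates the topk partial results of each token into the final output row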
    auto r = [&](auto i_token) {
        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            AccDataType acc = type_convert<AccDataType>(0);
            for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
            {
                acc += out_topk_tokens(i_token, i_topk, i_n);
            }
            o_host(i_token, i_n) = type_convert<ODataType>(acc);
        }
    };
    make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());

    (void)num_sorted_tiles_host;
    (void)sa_host;
    (void)sg_host;
    (void)sd_host;
    (void)sy_host;
}
} // namespace ck_tile
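
Below is a minimal host-side usage sketch, not part of the original header: a single token routed to a single expert with topk = 1, gate_only = 1 and block_m = 4. The activation functor MiniRelu, all tensor sizes and values, and the hand-written sorting metadata are made up for illustration (a real run would take the metadata from the moe-sorting step in the layout described by the comment block above); MiniRelu merely stands in for the suggested ck_tile::element_wise::Gelu.

#include "ck_tile/host/reference/reference_fused_moe.hpp"

// stand-in activation with the Activation{}(y, x) call shape the reference expects (assumption)
struct MiniRelu
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = x > static_cast<X>(0) ? static_cast<Y>(x) : static_cast<Y>(0);
    }
};

int main()
{
    using ck_tile::HostTensor;
    using ck_tile::index_t;

    const index_t tokens = 1, experts = 1, topk = 1, gate_only = 1;
    const index_t hidden_size = 8, intermediate_size = 16, block_m = 4;

    HostTensor<float> a({tokens, hidden_size});                     // input activations
    HostTensor<float> g({experts, intermediate_size, hidden_size}); // gate weights (gate_only)
    HostTensor<float> d({experts, hidden_size, intermediate_size}); // down weights
    HostTensor<float> o({tokens, hidden_size});                     // output
    // scale tensors are ignored by this reference, but still have to be passed
    HostTensor<float> sa({tokens, 1});
    HostTensor<float> sg({experts, 1, intermediate_size});
    HostTensor<float> sd({experts, 1, hidden_size});
    HostTensor<float> sy({experts, 1, intermediate_size});

    for(index_t k = 0; k < hidden_size; k++)
        a(0, k) = 0.01f * k;
    for(index_t n = 0; n < intermediate_size; n++)
        for(index_t k = 0; k < hidden_size; k++)
        {
            g(0, n, k) = 0.02f;
            d(0, k, n) = 0.03f;
        }

    // hand-written sorting metadata (a token id equal to `tokens` marks a padding slot)
    const index_t max_num_tokens_padded = topk * tokens + experts * block_m - topk; // = 4
    HostTensor<index_t> sorted_token_ids({max_num_tokens_padded});
    HostTensor<float> sorted_weight({max_num_tokens_padded});
    HostTensor<index_t> sorted_expert_ids({(max_num_tokens_padded + block_m - 1) / block_m});
    HostTensor<index_t> num_sorted_tiles({1}); // holds tokens-after-padding; divided by block_m inside
    HostTensor<index_t> token_ids({tokens, topk}); // topk expert ids per token

    sorted_token_ids.mData[0] = 0; // token 0 sits in the single tile of expert 0
    for(index_t i = 1; i < max_num_tokens_padded; i++)
        sorted_token_ids.mData[i] = tokens; // padding slots
    sorted_weight.mData[0]     = 1.0f;
    sorted_expert_ids.mData[0] = 0;
    num_sorted_tiles.mData[0]  = max_num_tokens_padded; // 4 padded slots -> 1 tile
    token_ids(0, 0)            = 0;                     // token 0 routed to expert 0

    // AccDataType must be given explicitly; the remaining types are deduced from the arguments
    ck_tile::reference_fused_moe<float, MiniRelu>(a, g, d, sa, sg, sd, sy, o,
                                                  sorted_token_ids, sorted_weight,
                                                  sorted_expert_ids, num_sorted_tiles, token_ids,
                                                  block_m, tokens, experts, hidden_size,
                                                  intermediate_size, topk, gate_only);
    return 0;
}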