Composable Kernel: include/ck_tile/host/reference/reference_moe_gemm.hpp — Source File
reference_moe_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <cstdlib>
7 #include <thread>
8 
9 #include "ck_tile/core.hpp"
11 
12 namespace ck_tile {
13 
// Reference (naive) MoE GEMM kernel: one thread computes one element of C.
//
// Launch: 1-D grid with a flat thread index; thread (row, col) covers
// row in [0, M) and col in [0, problem_N), where problem_N = N/2 for the
// fused gate+up gemm1 (MoeGemmKind == 1) and N otherwise.
//
// MoeGemmKind: 0 = gemm1 gate-only, 1 = gemm1 gate+up (B holds the gate
// weights in the first N/2 columns and the up weights in the second N/2),
// 2 = gemm2 (output is scaled by the expert routing weight and scattered
// back per token via atomicAdd, so contributions from the TopK experts sum).
//
// Token routing: each entry of p_sorted_token_ids_ packs
// (topk_slot << 24) | token_id (low 24 bits); p_sorted_expert_ids_ gives one
// expert id per block of TokensPerBlock sorted rows; p_max_token_id_[0] is
// the number of valid sorted rows. Rows whose token id >= Num_tokens are
// padding and are skipped.
//
// Quantization scales: scale_A_ptr / scale_B_ptr hold block-wise scales with
// granularity (scale_granularity_m|n) x scale_granularity_k; B scales are
// additionally strided per expert. Partial dot products are accumulated per
// k-tile in acc_temp / acc_up_temp and flushed with that tile's scales.
//
// expert_bias_ptr may be nullptr (no bias).
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC,
          int MoeGemmKind       = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
          typename ActivationOp = identity>
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
                                const ck_tile::index_t* p_sorted_expert_ids_,
                                const ck_tile::index_t* p_max_token_id_,
                                const ADataType* A,
                                const BDataType* B,
                                CDataType* C,
                                const AccDataType* expert_weight_ptr,
                                ck_tile::index_t Num_tokens,
                                ck_tile::index_t TokensPerBlock,
                                ck_tile::index_t TopK,
                                ck_tile::index_t M,
                                ck_tile::index_t N,
                                ck_tile::index_t K,
                                ck_tile::index_t strideA,
                                ck_tile::index_t strideB,
                                ck_tile::index_t strideC,
                                index_t scale_granularity_m,
                                index_t scale_granularity_n,
                                index_t scale_granularity_k,
                                float* scale_A_ptr,
                                float* scale_B_ptr,
                                float* expert_bias_ptr)
{
    int idx       = blockIdx.x * blockDim.x + threadIdx.x;
    int problem_N = MoeGemmKind == 1 ? N / 2 : N;
    int row       = idx / problem_N; // Compute row index
    int col       = idx % problem_N; // Compute column index

    index_t gather_token_id  = 0;
    index_t scatter_token_id = 0;
    index_t expert_id        = 0;

    if(row < p_max_token_id_[0])
    {
        expert_id = p_sorted_expert_ids_[row / TokensPerBlock];
        // Low 24 bits carry the token id; the high 8 bits carry the top-k slot.
        gather_token_id  = p_sorted_token_ids_[row] & 0xff'ffff;
        scatter_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
        if(gather_token_id >= Num_tokens)
        {
            return; // padding row inside a partially filled expert block
        }
        if(MoeGemmKind == 2)
        {
            // gemm2 reads the per-(token, topk) intermediate activations.
            gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
        }
        else
        {
            // gemm1 writes one output row per (token, topk) pair.
            scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
        }
    }
    else
    {
        return;
    }

    if(row < M)
    {
        AccDataType acc    = 0.0;
        AccDataType acc_up = 0.0;

        // Per-k-tile partial sums, flushed each time the scale tile changes.
        AccDataType acc_temp    = 0.0;
        AccDataType acc_up_temp = 0.0;

        float scale_A    = 0;
        float scale_B    = 0;
        float scale_B_up = 0;

        // Scale tensors are laid out [K / gran_k][ceil(dim / gran)], row-major over k-tiles.
        index_t scale_A_stride        = (M + scale_granularity_m - 1) / scale_granularity_m;
        index_t scale_B_stride        = (N + scale_granularity_n - 1) / scale_granularity_n;
        index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;

        // Sub-byte types pack two logical values per storage element; the hi/lo
        // halves are selected by k parity below.
        // NOTE(review): reconstructed — the original definitions of these constants
        // were lost in extraction; confirm against ck_tile's packed-size traits.
        constexpr index_t packed_size_a =
            (std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, pk_fp4_t>) ? 2 : 1;
        constexpr index_t packed_size_b =
            (std::is_same_v<BDataType, pk_int4_t> || std::is_same_v<BDataType, pk_fp4_t>) ? 2 : 1;

        for(int k = 0; k < K; ++k)
        {
            if(k % scale_granularity_k == 0)
            {
                // Flush the previous k-tile's partials with its scales (no-op at k == 0).
                acc += acc_temp * scale_A * scale_B;
                acc_up += acc_up_temp * scale_A * scale_B_up;
                // reset acc temp
                acc_temp    = 0.0;
                acc_up_temp = 0.0;
                // Load this k-tile's scale factors.
                scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
                                      (k / scale_granularity_k) * scale_A_stride];
                scale_B =
                    scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
                                (k / scale_granularity_k) * scale_B_stride];
                if constexpr(MoeGemmKind == 1)
                    scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
                                             (col + problem_N) / scale_granularity_n +
                                             (k / scale_granularity_k) * scale_B_stride];
            }

            // Adjust indexing based on matrix layout. 64-bit indices so large
            // M*K products and per-expert B offsets cannot overflow 32-bit int.
            long a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                               ? long(gather_token_id) * strideA + k
                               : long(k) * strideA + gather_token_id;

            long b_index =
                long(expert_id) * N * K +
                ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
                                                                             : k * strideB + col);
            long b_index_up;
            if constexpr(MoeGemmKind == 1)
                b_index_up = long(expert_id) * N * K +
                             ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                                  ? (col + problem_N) * strideB + k
                                  : k * strideB + col + problem_N);

            AccDataType v_a;
            AccDataType v_b;
            AccDataType v_b_up;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            {
                // NOTE(review): explicit unit scale added so the call matches
                // pk_fp4_to_fp32x2's (value, scale) signature, as used for B below.
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
                if constexpr(MoeGemmKind == 1)
                {
                    const fp32x2_t fp32_val_up =
                        pk_int4_t_to_fp32x2_t(B[b_index_up / packed_size_b]);
                    if(k % 2 == 1)
                        v_b_up = fp32_val_up.hi;
                    else
                        v_b_up = fp32_val_up.lo;
                }
            }
            else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
                if constexpr(MoeGemmKind == 1)
                {
                    const fp32x2_t fp32_val_up =
                        pk_fp4_to_fp32x2(B[b_index_up / packed_size_b], 1.0f);
                    if(k % 2 == 1)
                        v_b_up = fp32_val_up.hi;
                    else
                        v_b_up = fp32_val_up.lo;
                }
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
                if constexpr(MoeGemmKind == 1)
                    v_b_up = ck_tile::type_convert<AccDataType>(B[b_index_up]);
            }
            acc_temp += v_a * v_b;
            if constexpr(MoeGemmKind == 1)
                acc_up_temp += v_a * v_b_up;
        }

        // Flush the final (possibly partial) k-tile.
        acc += acc_temp * scale_A * scale_B;
        acc_up += acc_up_temp * scale_A * scale_B_up;

        float bias = 0.f, bias_up = 0.f;
        if(expert_bias_ptr != nullptr)
        {
            bias = expert_bias_ptr[expert_id * N + col];
            if constexpr(MoeGemmKind == 1)
                bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
        }

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? scatter_token_id * strideC + col
                          : col * strideC + scatter_token_id;
        if constexpr(MoeGemmKind < 2)
        {
            // gemm1: apply the activation; for gate+up the op combines both halves.
            C[c_index] = ck_tile::type_convert<CDataType>(
                ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
        }
        else
        {
            // moe gemm2 don't use activation: weight by the expert routing weight
            // and accumulate into the token's row (TopK experts add up here).
            CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
            // NOTE(review): reconstructed — the alternatives of this conditional were
            // lost in extraction; fp16x2_t / bf16x2_t match the documented typedefs.
            using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
                                                 fp16x2_t,
                                                 bf16x2_t>;
            ResV2Type add_v{0, 0};
            if(c_index % 2)
            {
                // result is the second value of fp16 pair.
                add_v.y = res;
            }
            else
            {
                // result is the first value of fp16 pair.
                add_v.x = res;
            }
            // mask last bit to make sure atomicAdd pointer is aligned of DWORD.
            atomic_add<ResV2Type>(reinterpret_cast<ResV2Type*>(C + (c_index & 0xffff'fffe)), add_v);
        }
    }
}
244 
// Host-side launcher for the reference MoE GEMM kernel.
//
// Configures a 1-D launch with exactly one thread per output element
// (M x problem_N, where problem_N = N/2 for the fused gate+up gemm1,
// MoeGemmKind == 1) and forwards every argument to moe_gemm_kernel.
// exp_bias may be left nullptr when no expert bias is used.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC,
          int MoeGemmKind       = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
          typename ActivationOp = identity>
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
                            const index_t* p_sorted_expert_ids_,
                            const index_t* p_max_token_id_,
                            const ADataType* a_ptr,
                            const BDataType* b_ptr,
                            CDataType* c_ptr,
                            const AccDataType* expert_weight_ptr,
                            index_t Num_tokens,
                            index_t TokensPerBlock,
                            index_t TopK,
                            index_t M,
                            index_t N,
                            index_t K,
                            index_t stride_a,
                            index_t stride_b,
                            index_t stride_c,
                            index_t scale_granularity_m,
                            index_t scale_granularity_n,
                            index_t scale_granularity_k,
                            float* scale_A_ptr,
                            float* scale_B_ptr,
                            float* exp_bias = nullptr)
{
    // One thread per C element; the fused gate+up gemm1 halves the output width.
    const int output_cols   = (MoeGemmKind == 1) ? N / 2 : N;
    const int total_threads = M * output_cols;
    constexpr int block_dim = 256; // common occupancy-friendly choice
    const int grid_dim      = (total_threads + block_dim - 1) / block_dim; // ceil-div

    moe_gemm_kernel<ADataType,
                    BDataType,
                    AccDataType,
                    CDataType,
                    LayoutA,
                    LayoutB,
                    LayoutC,
                    MoeGemmKind,
                    ActivationOp><<<grid_dim, block_dim>>>(p_sorted_token_ids_,
                                                           p_sorted_expert_ids_,
                                                           p_max_token_id_,
                                                           a_ptr,
                                                           b_ptr,
                                                           c_ptr,
                                                           expert_weight_ptr,
                                                           Num_tokens,
                                                           TokensPerBlock,
                                                           TopK,
                                                           M,
                                                           N,
                                                           K,
                                                           stride_a,
                                                           stride_b,
                                                           stride_c,
                                                           scale_granularity_m,
                                                           scale_granularity_n,
                                                           scale_granularity_k,
                                                           scale_A_ptr,
                                                           scale_B_ptr,
                                                           exp_bias);
}
315 
316 } // namespace ck_tile
Cross-references (from the generated documentation):
- Definition: cluster_descriptor.hpp:13 (symbol name lost in extraction; likely the ck_tile namespace)
- CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t &x) — Definition: pk_int4.hpp:105
- bfloat16_t bf16x2_t — Definition: pk_fp4.hpp:24
- float fp32x2_t — Definition: pk_fp4.hpp:22
- int32_t index_t — Definition: integer.hpp:9
- _Float16 fp16x2_t — Definition: half.hpp:385
- __global__ void moe_gemm_kernel(const ck_tile::index_t *p_sorted_token_ids_, const ck_tile::index_t *p_sorted_expert_ids_, const ck_tile::index_t *p_max_token_id_, const ADataType *A, const BDataType *B, CDataType *C, const AccDataType *expert_weight_ptr, ck_tile::index_t Num_tokens, ck_tile::index_t TokensPerBlock, ck_tile::index_t TopK, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr, float *expert_bias_ptr) — Definition: reference_moe_gemm.hpp:23
- void reference_moe_gemm_gpu(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, const AccDataType *expert_weight_ptr, index_t Num_tokens, index_t TokensPerBlock, index_t TopK, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr, float *exp_bias=nullptr) — Definition: reference_moe_gemm.hpp:254
- constexpr CK_TILE_HOST_DEVICE fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t &x, float scale) — Definition: pk_fp4.hpp:350
- Definition: numeric.hpp:81 (symbol name lost in extraction)