Composable Kernel: fused_moegemm_kernel.hpp Source File

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <string>
#include <type_traits>

// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice maps to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk = 3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                              tok-0      tok-1      tok-2      tok-3      tok-4
// topk_weight is            : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is  : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [],   [0, 1, 2, 4]]
//  (only for reference)     exp-0 exp-1      exp-2   exp-3            exp-4 exp-5
// weight_id_per_expert is : [[a], [g, j, m], [d, k], [b, e, h, l, n], [],   [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than the actual size, since the actual tokens are on the GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
//                        |- exp-0  -|- exp-1  -|- exp-2  -|-        exp-3        -|- exp-4  -|-  exp-5  -|
// sorted_weight_ptr    : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
// currently we do not have topk information in the data of token_id_per_expert/sorted_token_ids_ptr.
// In some cases (like smooth-quant) we need the topk information to index into the per-token
// quantized data and the per-expert smooth-quant scales. So we modify the number stored inside
// token_id_per_expert/sorted_token_ids_ptr:
//
//  32 bit:   0........23 24.....31
//  (data) -> (token_id   | topk_id)
// the low 24 bits hold the token id, the high 8 bits hold the topk id
//
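// Illustrative sketch of the 24/8-bit packing described above (a standalone helper
// sketch; the function names are placeholders and are not part of this kernel):
inline int moe_pack_sorted_id(int token_id, int topk_id)
{
    // low 24 bits: token id, high 8 bits: topk id
    return (topk_id << 24) | (token_id & 0xffffff);
}
inline int moe_unpack_token_id(int packed) { return packed & 0xffffff; }
inline int moe_unpack_topk_id(int packed) { return (packed >> 24) & 0xff; }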
// the input after smooth-quant is [token, topk, hidden_dim], originally it is [token, hidden_dim]
// the input scale per token is [topk, token, 1], the smooth-quant scale for the first gemm is [expert, interm_dim]
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * differences from vLLM:
// 1) the token_id stored in sorted_token_ids_ptr is the actual token_id, not the token_id*top_K expanded id
// 2) needs sorted_weight_ptr
// 3) uses num_sorted_tiles_ptr, already divided by M_a
//
// * the following are used for indexing:
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
//
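// Illustrative host-side reference for the sorting/padding scheme above (a sketch,
// not part of this kernel; in practice these buffers come from the moe-sorting kernel,
// and the pad sentinel value as well as the use of plain, unpacked token ids here are
// assumptions that match the worked example above):
inline void reference_moe_sort(const int*   topk_ids,    // [num_tokens, topk]
                               const float* topk_weight, // [num_tokens, topk]
                               int num_tokens,
                               int topk,
                               int num_experts,
                               int block_m,               // M_a
                               int*   sorted_token_ids,   // [max_num_tokens_padded]
                               float* sorted_weight,      // [max_num_tokens_padded]
                               int*   sorted_expert_ids,  // [max_num_tokens_padded / block_m]
                               int*   num_tokens_post_padded)
{
    const int pad_token = num_tokens + 1; // pad entry (6 in the example above), assumed
    int out             = 0;
    int tile            = 0;
    for(int e = 0; e < num_experts; ++e)
    {
        int cnt = 0;
        for(int t = 0; t < num_tokens; ++t)
        {
            for(int k = 0; k < topk; ++k)
            {
                if(topk_ids[t * topk + k] == e)
                {
                    sorted_token_ids[out] = t; // or (k << 24) | t, see the note above
                    sorted_weight[out]    = topk_weight[t * topk + k];
                    ++out;
                    ++cnt;
                }
            }
        }
        // pad every expert slice (even an empty one) up to a multiple of block_m
        const int padded = (cnt == 0) ? block_m : ((cnt + block_m - 1) / block_m) * block_m;
        for(; cnt < padded; ++cnt, ++out)
        {
            sorted_token_ids[out] = pad_token;
            sorted_weight[out]    = 0.f;
        }
        // one expert id per tile of block_m rows
        for(int b = 0; b < padded / block_m; ++b)
        {
            sorted_expert_ids[tile++] = e;
        }
    }
    *num_tokens_post_padded = out; // num_sorted_tiles = out / block_m
}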
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                              tok-0      tok-1      tok-2      tok-3      tok-4
// topk_weight is            : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate the original row/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
// let x be one element of the above, then:
//   topk_row_id (token_id)      = x % num_tokens (5)
//   topk_col_id (topk position) = x / num_tokens
// topk_row_id/col_id can be used to access the original topk_ids/topk_weight
//
// token_id_per_expert is  : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [],   [0, 1, 2, 4]]
//  (only for reference)     exp-0 exp-1      exp-2   exp-3            exp-4 exp-5
// weight_id_per_expert is : [[a], [g, j, m], [d, k], [b, e, h, l, n], [],   [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
// clang-format on
//
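// Illustrative sketch of the row/col decoding used by [indexing implementation-2]
// above (standalone helpers, not part of this kernel): an entry x of topk_rc_ids
// maps back to its token (row) and topk position (col); e.g. with num_tokens = 5,
// x = 8 decodes to token 3, topk position 1, and x = D (13) to token 3, position 2.
inline int moe_rc_to_token_id(int x, int num_tokens) { return x % num_tokens; }
inline int moe_rc_to_topk_id(int x, int num_tokens) { return x / num_tokens; }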
namespace ck_tile {

// m: num_tokens (or token * input-batch)
// k: hidden_size
// n: intermediate_size used between the 2 FCs (TP slices this)
// e: num_experts
// if doing pre-shuffle
// nr : n / Block_Nr
// kr : k / Block_Kr
// w  : flattened 1d wave buffer
struct FusedMoeGemmHostArgs
{
    const void* a_ptr;              // [m, k], input token
    const void* a_scale_ptr;        // [m, 1], token scale
    const void* g_ptr;              // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
    const void* d_ptr;              // [e, n, k], pre-shuffle([e, nr, kr, w])
    const void* g_scale_ptr;        // [e, 1, n], gate(up) scale
    const void* d_scale_ptr;        // [e, 1, k], down scale
    const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
    void* o_ptr;                    // [m, k], output token

    const void* sorted_token_ids_ptr;  // [max_num_tokens_padded]
    const void* sorted_weight_ptr;     // [max_num_tokens_padded]
    const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
    const void* num_sorted_tiles_ptr;  // [1]

    index_t hidden_size;       // k
    index_t intermediate_size; // n / TP, for Gate/Up/Down
    index_t num_tokens;        // input number of tokens for the current iteration
    index_t num_experts;       // number of groups
    index_t topk;              // need this?

    index_t stride_token; // stride for each input/output row, should be >= hidden_size
};
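// Illustrative sketch of how a caller could fill the host arguments above (a
// standalone example, not part of this kernel; the function name and all size
// values are assumed, and the null pointers stand in for the caller's device
// allocations plus the moe-sorting outputs):
inline FusedMoeGemmHostArgs make_example_fused_moe_host_args()
{
    FusedMoeGemmHostArgs h{};
    h.a_ptr                 = nullptr; // [m, k] input tokens (device)
    h.a_scale_ptr           = nullptr; // [m, 1] per-token scale
    h.g_ptr                 = nullptr; // pre-shuffled gate(/up) weight [e, nr, kr, w]
    h.d_ptr                 = nullptr; // pre-shuffled down weight [e, nr, kr, w]
    h.g_scale_ptr           = nullptr;
    h.d_scale_ptr           = nullptr;
    h.y_smooth_scale_ptr    = nullptr; // only used with smooth-quant
    h.o_ptr                 = nullptr; // [m, k] output tokens (device)
    h.sorted_token_ids_ptr  = nullptr; // moe-sorting outputs follow
    h.sorted_weight_ptr     = nullptr;
    h.sorted_expert_ids_ptr = nullptr;
    h.num_sorted_tiles_ptr  = nullptr;
    h.hidden_size           = 4096;  // k (assumed)
    h.intermediate_size     = 14336; // n / TP (assumed)
    h.num_tokens            = 128;   // m (assumed)
    h.num_experts           = 8;     // e (assumed)
    h.topk                  = 2;     // (assumed)
    h.stride_token          = h.hidden_size;
    return h;
}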

// This is the scatter/gather b2b group-gemm
template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
struct FusedMoeGemmKernel
{
    using Partitioner = remove_cvref_t<Partitioner_>;
    using Pipeline    = remove_cvref_t<Pipeline_>;
    using Epilogue    = remove_cvref_t<Epilogue_>; // TODO: not used
    // static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
    // static_assert(kBlockPerCu > 0);

    using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
    static constexpr index_t kBlockSize = BlockShape::BlockSize;

    using ADataType            = typename Pipeline::Problem::ADataType;
    using GDataType            = typename Pipeline::Problem::GDataType;
    using DDataType            = typename Pipeline::Problem::DDataType;
    using AccDataType          = typename Pipeline::Problem::AccDataType;
    using ODataType            = typename Pipeline::Problem::ODataType;
    using AScaleDataType       = typename Pipeline::Problem::AScaleDataType;
    using GScaleDataType       = typename Pipeline::Problem::GScaleDataType;
    using DScaleDataType       = typename Pipeline::Problem::DScaleDataType;
    using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType;
    using TopkWeightDataType   = typename Pipeline::Problem::TopkWeightDataType;
    using IndexDataType        = typename Pipeline::Problem::IndexDataType;
    using YDataType            = typename Pipeline::Problem::YDataType;

    using Traits = typename Pipeline::Problem::Traits;
    static constexpr bool UseUK = true;

    static constexpr bool IsGateOnly          = Traits::IsGateOnly;
    static constexpr bool UseSmoothQuant      = Traits::UseSmoothQuant;
    static constexpr bool PadHiddenSize       = Traits::PadHiddenSize;
    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float>  { static constexpr const char * name = "fp32"; };
    template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<fp8_t>  { static constexpr const char * name = "fp8"; };
    template <> struct t2s<bf8_t>  { static constexpr const char * name = "bf8"; };
    template <> struct t2s<int8_t> { static constexpr const char * name = "int8"; };
    // clang-format on

    CK_TILE_HOST static std::string GetName()
    {
#define _SS_ std::string
#define _TS_ std::to_string
        // clang-format off
        using S_ = BlockShape;

        auto prec_str = [&] () {
            std::string base_str = _SS_(t2s<ADataType>::name);
            if (!std::is_same_v<ADataType, GDataType>) {
                base_str += _SS_("_") + _SS_(t2s<GDataType>::name);
            }
            return base_str;
        }();

        return _SS_("fused_moe_") + _SS_(prec_str) + "_" + (IsGateOnly ? "g1u0_" : "g1u1_") +
            _TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
            _TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
            _TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
#undef _SS_
#undef _TS_
        // clang-format on
    }
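    // For illustration only (all shape numbers assumed): with fp16 A/G data, gate-only
    // weights (g1u0), Block_{M0,N0,K0,N1} = 32x128x128x128, 1x4x1 warps per block,
    // 32x32x16 warp tiles and a pipeline whose name is "flatmm", GetName() would
    // produce "fused_moe_fp16_g1u0_32x128x128x128_1x4x1_32x32x16_flatmm".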

    struct Kargs
    {
        const void* a_ptr;              // [m, k], input token
        const void* a_scale_ptr;        // [m, 1], token scale
        const void* g_ptr;              // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
        const void* d_ptr;              // [e, n, k], pre-shuffle([e, nr, kr, w])
        const void* g_scale_ptr;        // [e, 1, n], gate(up) scale
        const void* d_scale_ptr;        // [e, 1, k], down scale
        const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
        void* o_ptr;                    // [m, k], output token

        const void* sorted_token_ids_ptr;
        const void* sorted_weight_ptr;
        const void* sorted_expert_ids_ptr;
        const void* num_sorted_tiles_ptr;

        index_t hidden_size;       // k
        index_t intermediate_size; // n / TP, for Gate/Up/Down
        index_t num_tokens;        // input number of tokens for the current iteration
        index_t num_experts;       // number of groups
        index_t topk;              // need this?

        index_t stride_token; // stride for each input/output row, should be >= hidden_size
    };

    // TODO: switch karg based on
    using Hargs = FusedMoeGemmHostArgs;

    CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
    {
        // TODO: hargs/kargs are not guaranteed to be the same
        return bit_cast<Kargs>(hargs);
    }
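    // Because MakeKargs() simply bit_casts Hargs to Kargs, the two layouts must stay in
    // sync. A guard one could add (a sketch/assumption, not something this kernel declares):
    static_assert(sizeof(Kargs) == sizeof(Hargs),
                  "Hargs and Kargs must be layout-compatible for bit_cast");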

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
    {
        constexpr index_t block_m = BlockShape::Block_M0;
        int max_num_tokens_padded =
            hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
        // printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
        return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
    }
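    // Worked example (numbers assumed): with topk = 3, num_tokens = 5, num_experts = 6
    // and Block_M0 = 4, max_num_tokens_padded = 3 * 5 + 6 * 4 - 3 = 36, which upper-bounds
    // the 28 actually-used entries in the sorting example above; the partitioner then
    // tiles this padded M range against intermediate_size to form the launch grid.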

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        if constexpr(UseUK)
        {
            __shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
            IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
                *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));

            num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0;

            const auto [sorted_tile_id, intermediate_tile_id] =
                Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
            // if(threadIdx.x == 0)
            //     printf("bid:%d,%d, num_sorted_tiles:%d, sorted_tile_id:%d(%d),
            //     intermediate_tile_id:%d\n", static_cast<int>(blockIdx.x),
            //     static_cast<int>(blockIdx.y), num_sorted_tiles, sorted_tile_id, sorted_tile_id >=
            //     num_sorted_tiles? 1 : 0, intermediate_tile_id);
            if(sorted_tile_id >= num_sorted_tiles)
                return;

            Pipeline{}(kargs, smem, sorted_tile_id, intermediate_tile_id);
        }
        else
        {
            // allocate LDS
            // __shared__ char smem_ptr[GetSmemSize()];
            IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
                *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
            constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;

            index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
            index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
            index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
            index_t kr_1 =
                kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0

            index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
            index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;

            __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];

            // note this is in units of tiles; multiply by the tile size to get the element index
            const auto [sorted_tile_id, intermediate_tile_id] =
                Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
            if(sorted_tile_id >= num_sorted_tiles)
                return;

            const IndexDataType expert_id =
                __builtin_amdgcn_readfirstlane(reinterpret_cast<const IndexDataType*>(
                    kargs.sorted_expert_ids_ptr)[sorted_tile_id]);

            // index along intermediate_size
            // index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
            // BlockShape::Block_N0);
            index_t interm_idx_nr =
                __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);

            const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
            const auto sorted_token_id =
                a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;

            index_t token_id =
                reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
            token_id &= 0xffffff;
#endif
            auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
                kargs.sorted_weight_ptr)[sorted_token_id];

            const auto a_window = [&]() {
                // A is already pre-padded in the previous kernel
                const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
                const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(kargs.num_tokens, kargs.hidden_size),
                    make_tuple(kargs.stride_token, 1),
                    number<1>{});

                // gather is here, using an indexing transform
                const auto a_gather_view_ = transform_tensor_view(
                    a_view_,

                const auto a_window_ = make_tile_window(
                    a_gather_view_,
                    {0, 0});
                return a_window_;
            }();

            // TODO: gtile using NSub to have less register pressure
            const auto g_window = [&]() {
                const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                         static_cast<long_index_t>(expert_id) * expert_stride_0 +
                                         interm_idx_nr * kr_0 * BlockShape::Block_W0;
                const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                    g_ptr,
                    make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                    number<1>{});
                const auto g_view_1_ =
                    pad_tensor_view(g_view_,

                const auto g_window_ = make_tile_window(g_view_1_,
                                                        {0, 0, 0});
                return g_window_;
            }();

            const auto d_window = [&]() {
                const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                                         static_cast<long_index_t>(expert_id) * expert_stride_1 +
                                         interm_idx_nr * BlockShape::Block_W1;
                // note interm_idx_nr is along the gemm-k dim of the 2nd gemm

                const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
                    d_ptr,
                    make_tuple(nr_1, kr_1, BlockShape::Block_W1),
                    make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
                    number<1>{});
                const auto d_view_1_ =
                    pad_tensor_view(d_view_,

                const auto d_window_ = make_tile_window(d_view_1_,
                                                        {0, 0, 0});
                return d_window_;
            }();

            auto o_window = [&]() {
                ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
                auto o_view_ = make_naive_tensor_view<address_space_enum::global,
                    o_ptr,
                    make_tuple(kargs.num_tokens, kargs.hidden_size),
                    make_tuple(kargs.stride_token, 1),
                    number<1>{});

                // scatter is here
                auto o_scatter_view_ = transform_tensor_view(
                    o_view_,

                auto o_window_ = make_tile_window(
                    o_scatter_view_,
                    {0, 0});
                return o_window_;
            }();

            // do compute yeah
            Pipeline{}(a_window,
                       g_window,
                       d_window,
                       o_window,
                       topk_weight,
                       smem,
                       kargs.hidden_size,
                       kargs.intermediate_size,
                       kargs.stride_token);
        }
    }
};

} // namespace ck_tile
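
// Illustrative host-side launch sketch (an assumption, not part of this header):
// the kernel above is a device-side functor, so one common pattern is a thin
// __global__ trampoline. The partitioner/pipeline/epilogue types, the hargs setup
// and the stream below are placeholders the caller has to provide.
//
//   template <typename Kernel>
//   __global__ void kentry_sketch(typename Kernel::Kargs kargs)
//   {
//       Kernel{}(kargs);
//   }
//
//   using MyKernel = ck_tile::FusedMoeGemmKernel<MyPartitioner, MyPipeline, MyEpilogue>;
//   auto kargs  = MyKernel::MakeKargs(hargs);
//   auto grids  = MyKernel::GridSize(hargs);
//   auto blocks = MyKernel::BlockSize();
//   kentry_sketch<MyKernel><<<grids, blocks, 0, stream>>>(kargs);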