/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp Source File
fused_moegemm_kernel.hpp
Go to the documentation of this file.
28 // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
30 // sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
35 // Currently, the data in token_id_per_expert/sorted_token_ids_ptr does not carry any topk information.
37 // different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
43 // the input after smooth-quant is [token, topk, hidden_dim]; before smooth-quant it is [token, hidden_dim]
44 // the input scale for token is [topk, token, 1]; the smooth-quant scale for the first gemm is [expert, interm_dim]
182 _TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
183 _TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
184 _TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:511
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_indexing_transform(const UpLength &up_lengths, const Indices &indices)
Definition: coordinate_transform.hpp:1680
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_view.hpp:471
__device__ X atomic_add(X *p_dst, const X &x)
Definition: fused_moegemm_kernel.hpp:98
const void * sorted_expert_ids_ptr
Definition: fused_moegemm_kernel.hpp:110
const void * num_sorted_tiles_ptr
Definition: fused_moegemm_kernel.hpp:111
const void * a_scale_ptr
Definition: fused_moegemm_kernel.hpp:100
const void * g_scale_ptr
Definition: fused_moegemm_kernel.hpp:103
const void * sorted_weight_ptr
Definition: fused_moegemm_kernel.hpp:109
const void * d_scale_ptr
Definition: fused_moegemm_kernel.hpp:104
index_t num_experts
Definition: fused_moegemm_kernel.hpp:116
const void * sorted_token_ids_ptr
Definition: fused_moegemm_kernel.hpp:108
index_t intermediate_size
Definition: fused_moegemm_kernel.hpp:114
index_t hidden_size
Definition: fused_moegemm_kernel.hpp:113
const void * y_smooth_scale_ptr
Definition: fused_moegemm_kernel.hpp:105
index_t stride_token
Definition: fused_moegemm_kernel.hpp:119
Definition: fused_moegemm_kernel.hpp:191
index_t topk
Definition: fused_moegemm_kernel.hpp:210
void * o_ptr
Definition: fused_moegemm_kernel.hpp:199
const void * sorted_expert_ids_ptr
Definition: fused_moegemm_kernel.hpp:203
index_t intermediate_size
Definition: fused_moegemm_kernel.hpp:207
index_t hidden_size
Definition: fused_moegemm_kernel.hpp:206
const void * y_smooth_scale_ptr
Definition: fused_moegemm_kernel.hpp:198
const void * a_ptr
Definition: fused_moegemm_kernel.hpp:192
index_t num_tokens
Definition: fused_moegemm_kernel.hpp:208
const void * g_scale_ptr
Definition: fused_moegemm_kernel.hpp:196
const void * d_ptr
Definition: fused_moegemm_kernel.hpp:195
index_t num_experts
Definition: fused_moegemm_kernel.hpp:209
const void * a_scale_ptr
Definition: fused_moegemm_kernel.hpp:193
const void * sorted_weight_ptr
Definition: fused_moegemm_kernel.hpp:202
const void * g_ptr
Definition: fused_moegemm_kernel.hpp:194
index_t stride_token
Definition: fused_moegemm_kernel.hpp:212
const void * num_sorted_tiles_ptr
Definition: fused_moegemm_kernel.hpp:204
const void * d_scale_ptr
Definition: fused_moegemm_kernel.hpp:197
const void * sorted_token_ids_ptr
Definition: fused_moegemm_kernel.hpp:201
Definition: fused_moegemm_kernel.hpp:157
Definition: fused_moegemm_kernel.hpp:125
static constexpr bool UseUK
Definition: fused_moegemm_kernel.hpp:149
typename Pipeline::Problem::ADataType ADataType
Definition: fused_moegemm_kernel.hpp:135
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fused_moegemm_kernel.hpp:238
typename Pipeline::Problem::GDataType GDataType
Definition: fused_moegemm_kernel.hpp:136
typename Pipeline::Problem::Traits Traits
Definition: fused_moegemm_kernel.hpp:148
static constexpr bool PadIntermediateSize
Definition: fused_moegemm_kernel.hpp:154
typename Pipeline::Problem::TopkWeightDataType TopkWeightDataType
Definition: fused_moegemm_kernel.hpp:144
remove_cvref_t< Partitioner_ > Partitioner
Definition: fused_moegemm_kernel.hpp:126
typename Pipeline::Problem::DDataType DDataType
Definition: fused_moegemm_kernel.hpp:137
remove_cvref_t< Pipeline_ > Pipeline
Definition: fused_moegemm_kernel.hpp:127
static constexpr bool UseSmoothQuant
Definition: fused_moegemm_kernel.hpp:152
static constexpr index_t kBlockSize
Definition: fused_moegemm_kernel.hpp:133
typename Pipeline::Problem::DScaleDataType DScaleDataType
Definition: fused_moegemm_kernel.hpp:142
typename Pipeline::Problem::AScaleDataType AScaleDataType
Definition: fused_moegemm_kernel.hpp:140
typename Pipeline::Problem::GScaleDataType GScaleDataType
Definition: fused_moegemm_kernel.hpp:141
typename Pipeline::Problem::AccDataType AccDataType
Definition: fused_moegemm_kernel.hpp:138
typename Pipeline::Problem::ODataType ODataType
Definition: fused_moegemm_kernel.hpp:139
typename Pipeline::Problem::IndexDataType IndexDataType
Definition: fused_moegemm_kernel.hpp:145
static constexpr bool IsGateOnly
Definition: fused_moegemm_kernel.hpp:151
static constexpr bool PadHiddenSize
Definition: fused_moegemm_kernel.hpp:153
typename Pipeline::Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition: fused_moegemm_kernel.hpp:143
remove_cvref_t< Epilogue_ > Epilogue
Definition: fused_moegemm_kernel.hpp:128
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: fused_moegemm_kernel.hpp:236
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: fused_moegemm_kernel.hpp:219
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fused_moegemm_kernel.hpp:234
typename Pipeline::Problem::YDataType YDataType
Definition: fused_moegemm_kernel.hpp:146
typename Pipeline::BlockShape BlockShape
Definition: fused_moegemm_kernel.hpp:132
static CK_TILE_HOST std::string GetName()
Definition: fused_moegemm_kernel.hpp:166
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: fused_moegemm_kernel.hpp:225
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49