14 template <
typename ADataType,
 
   22           typename ActivationOp = identity>
 
   29                                 const AccDataType* expert_weight_ptr,
 
   44                                 float* expert_bias_ptr)
 
   46     int idx       = blockIdx.x * blockDim.x + threadIdx.x;
 
   47     int problem_N = MoeGemmKind == 1 ? N / 2 : N;
 
   48     int row       = idx / problem_N; 
 
   49     int col       = idx % problem_N; 
 
   55     if(row < p_max_token_id_[0])
 
   57         expert_id        = p_sorted_expert_ids_[row / TokensPerBlock];
 
   58         gather_token_id  = p_sorted_token_ids_[row] & 0xff'ffff;
 
   59         scatter_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
 
   60         if(gather_token_id >= Num_tokens)
 
   66             gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
 
   70             scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
 
   80         AccDataType acc    = 0.0;
 
   81         AccDataType acc_up = 0.0;
 
   83         AccDataType acc_temp    = 0.0;
 
   84         AccDataType acc_up_temp = 0.0;
 
   90         index_t scale_A_stride        = (M + scale_granularity_m - 1) / scale_granularity_m;
 
   91         index_t scale_B_stride        = (N + scale_granularity_n - 1) / scale_granularity_n;
 
   92         index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;
 
   94         for(
int k = 0; k < K; ++k)
 
   96             if(k % scale_granularity_k == 0)
 
   99                 acc += acc_temp * scale_A * scale_B;
 
  100                 acc_up += acc_up_temp * scale_A * scale_B_up;
 
  105                 scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
 
  106                                       (k / scale_granularity_k) * scale_A_stride];
 
  108                     scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
 
  109                                 (k / scale_granularity_k) * scale_B_stride];
 
  110                 if constexpr(MoeGemmKind == 1)
 
  111                     scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
 
  112                                              (col + problem_N) / scale_granularity_n +
 
  113                                              (k / scale_granularity_k) * scale_B_stride];
 
  119             int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
 
  120                               ? gather_token_id * strideA + k
 
  121                               : k * strideA + gather_token_id;
 
  124                 long(expert_id) * N * K +
 
  125                 ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
 
  126                                                                              : k * strideB + col);
 
  128             if constexpr(MoeGemmKind == 1)
 
  129                 b_index_up = long(expert_id) * N * K +
 
  130                              ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
 
  131                                   ? (col + problem_N) * strideB + k
 
  132                                   : k * strideB + col + problem_N);
 
  137             if constexpr(std::is_same_v<ADataType, pk_int4_t>)
 
  145             else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
 
  155                 v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
 
  157             if constexpr(std::is_same_v<BDataType, pk_int4_t>)
 
  164                 if constexpr(MoeGemmKind == 1)
 
  169                         v_b_up = fp32_val_up.hi;
 
  171                         v_b_up = fp32_val_up.lo;
 
  174             else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
 
  181                 if constexpr(MoeGemmKind == 1)
 
  186                         v_b_up = fp32_val_up.hi;
 
  188                         v_b_up = fp32_val_up.lo;
 
  193                 v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
 
  194                 if constexpr(MoeGemmKind == 1)
 
  195                     v_b_up = ck_tile::type_convert<AccDataType>(B[b_index_up]);
 
  197             acc_temp += v_a * v_b;
 
  198             if constexpr(MoeGemmKind == 1)
 
  199                 acc_up_temp += v_a * v_b_up;
 
  202         acc += acc_temp * scale_A * scale_B;
 
  203         acc_up += acc_up_temp * scale_A * scale_B_up;
 
  205         float bias = 0.f, bias_up = 0.f;
 
  206         if(expert_bias_ptr != 
nullptr)
 
  208             bias = expert_bias_ptr[expert_id * N + col];
 
  209             if constexpr(MoeGemmKind == 1)
 
  210                 bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
 
  213         int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
 
  214                           ? scatter_token_id * strideC + col
 
  215                           : col * strideC + scatter_token_id;
 
  216         if constexpr(MoeGemmKind < 2)
 
  218             C[c_index] = ck_tile::type_convert<CDataType>(
 
  219                 ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
 
  224             CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
 
  225             using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
 
  228             ResV2Type add_v{0, 0};
 
  240             atomic_add<ResV2Type>(
reinterpret_cast<ResV2Type*
>(C + (c_index & 0xffff'fffe)), add_v);
 
  245 template <
typename ADataType,
 
  247           typename AccDataType,
 
  253           typename ActivationOp = identity>
 
  255                             const index_t* p_sorted_expert_ids_,
 
  256                             const index_t* p_max_token_id_,
 
  257                             const ADataType* a_ptr,
 
  258                             const BDataType* b_ptr,
 
  260                             const AccDataType* expert_weight_ptr,
 
  275                             float* exp_bias = 
nullptr)
 
  277     int problem_N          = MoeGemmKind == 1 ? N / 2 : N;
 
  278     int totalElements      = M * problem_N;
 
  279     int numThreadsPerBlock = 256; 
 
  280     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
  290                     ActivationOp><<<numBlocks, numThreadsPerBlock>>>(p_sorted_token_ids_,
 
  291                                                                      p_sorted_expert_ids_,
 
Definition: cluster_descriptor.hpp:13
 
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:105
 
bfloat16_t bf16x2_t
Definition: pk_fp4.hpp:24
 
float fp32x2_t
Definition: pk_fp4.hpp:22
 
int32_t index_t
Definition: integer.hpp:9
 
_Float16 fp16x2_t
Definition: half.hpp:385
 
__global__ void moe_gemm_kernel(const ck_tile::index_t *p_sorted_token_ids_, const ck_tile::index_t *p_sorted_expert_ids_, const ck_tile::index_t *p_max_token_id_, const ADataType *A, const BDataType *B, CDataType *C, const AccDataType *expert_weight_ptr, ck_tile::index_t Num_tokens, ck_tile::index_t TokensPerBlock, ck_tile::index_t TopK, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr, float *expert_bias_ptr)
Definition: reference_moe_gemm.hpp:23
 
void reference_moe_gemm_gpu(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, const AccDataType *expert_weight_ptr, index_t Num_tokens, index_t TokensPerBlock, index_t TopK, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr, float *exp_bias=nullptr)
Definition: reference_moe_gemm.hpp:254
 
constexpr CK_TILE_HOST_DEVICE fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:350
 
Definition: numeric.hpp:81