/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_moe_gemm.hpp File Reference

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_moe_gemm.hpp File Reference

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_moe_gemm.hpp File Reference
reference_moe_gemm.hpp File Reference
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

Go to the source code of this file.

Namespaces

 ck_tile
 

Functions

template<typename ADataType , typename BDataType , typename AccDataType , typename CDataType , typename LayoutA , typename LayoutB , typename LayoutC , int MoeGemmKind = 0, typename ActivationOp = identity>
__global__ void ck_tile::moe_gemm_kernel (const ck_tile::index_t *p_sorted_token_ids_, const ck_tile::index_t *p_sorted_expert_ids_, const ck_tile::index_t *p_max_token_id_, const ADataType *A, const BDataType *B, CDataType *C, const AccDataType *expert_weight_ptr, ck_tile::index_t Num_tokens, ck_tile::index_t TokensPerBlock, ck_tile::index_t TopK, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr, float *expert_bias_ptr)
 
template<typename ADataType , typename BDataType , typename AccDataType , typename CDataType , typename LayoutA , typename LayoutB , typename LayoutC , int MoeGemmKind = 0, typename ActivationOp = identity>
void ck_tile::reference_moe_gemm_gpu (const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, const AccDataType *expert_weight_ptr, index_t Num_tokens, index_t TokensPerBlock, index_t TopK, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr, float *exp_bias=nullptr)