template <typename ADataType,
          // ... (QDataType, BDataType, AccDataType, CDataType template parameters elided)
          typename QuantGroupSize,
          // ... (remaining template parameters, including the element ops, elided)
          >
CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
                                       const HostTensor<QDataType>& q,
                                       const HostTensor<BDataType>& b_k_n,
                                       HostTensor<CDataType>& c_m_n,
                                       const AElementOp& a_element_op     = {},
                                       const BElementOp& b_element_op     = {},
                                       const ACCElementOp& acc_element_op = {})
{
    // ...
    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0, v_block_acc = 0;

        static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
                      std::is_same_v<ADataType, bf8_t>);
        static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
                      std::is_same_v<BDataType, pk_int4_t>);
        static_assert(std::is_same_v<AccDataType, float>);
        static_assert(std::is_same_v<CDataType, float> ||
                      std::is_same_v<CDataType, ck_tile::half_t>);

        for(std::size_t k = 0; k < K; ++k)
        {
            // v_a / v_b hold the dequantized A and B values for this k (declarations elided)
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                // ... unpack the packed int4 pair and select the element belonging to this k
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }

            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                // ... unpack the packed int4 pair and select the element belonging to this k
            }
            else if constexpr(std::is_same_v<BDataType, fp8_t>)
            {
                // ...
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }

            v_block_acc += v_a * v_b;

            // a quantization group is complete every QuantGroupSize::kK steps along K
            if((k + 1) % QuantGroupSize::kK == 0)
            {
                // aquant selects whether q is laid out for A-side (M-group by K-group)
                // or B-side (K-group by N-group) quantization
                index_t outer_dim = (aquant) ? (m / QuantGroupSize::kM) : (k / QuantGroupSize::kK);
                index_t inner_dim = (aquant) ? (k / QuantGroupSize::kK) : (n / QuantGroupSize::kN);
                if constexpr(std::is_same_v<QDataType, float>)
                {
                    scale = q(outer_dim, inner_dim);
                }
                else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
                {
                    // ... convert the raw fp8 scale to float
                }
                else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
                {
                    // ... convert the raw bf8 scale to float
                }
                else
                {
                    static_assert(false, "Unexpected Q datatype.");
                }
                v_block_acc *= scale;
                v_acc += v_block_acc;
                // ... (the partial group accumulator is reset here for the next group)
            }
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };
    // ... (f_mn is applied over all (m, n) of c_m_n)
    std::cout << std::endl;
}
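The grouped-quant reference accumulates a partial dot product per K-group and rescales it before folding it into the final result. The following standalone sketch (plain C++, no ck_tile types; all names are illustrative, not from this listing) shows that accumulation order for one output element:

    #include <cstddef>
    #include <vector>

    // Illustration only: per-group dequantization as done by reference_gemm_quant,
    // for a single (m, n) output with group size kK along K.
    float group_quant_dot(const std::vector<float>& a_row,   // already element-op'd A(m, :)
                          const std::vector<float>& b_col,   // already element-op'd B(:, n)
                          const std::vector<float>& scales,  // one scale per K-group
                          std::size_t kK)                    // QuantGroupSize::kK
    {
        float v_acc = 0.f, v_block_acc = 0.f;
        for(std::size_t k = 0; k < a_row.size(); ++k)
        {
            v_block_acc += a_row[k] * b_col[k];
            if((k + 1) % kK == 0)
            {
                v_acc += v_block_acc * scales[k / kK]; // rescale the finished group
                v_block_acc = 0.f;                     // start the next group
            }
        }
        return v_acc;
    }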
 
template <typename ADataType,
          // ... (AQDataType, BDataType, BQDataType elided)
          typename AccDataType,
          // ... (CDataType and element-op template parameters elided)
          >
CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k,
                                              const HostTensor<AQDataType>& aq_m_1,
                                              const HostTensor<BDataType>& b_k_n,
                                              const HostTensor<BQDataType>& bq_1_n,
                                              HostTensor<CDataType>& c_m_n,
                                              const AElementOp& a_element_op     = {},
                                              const BElementOp& b_element_op     = {},
                                              const ACCElementOp& acc_element_op = {})
{
    static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
    static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
    static_assert(std::is_same_v<AccDataType, float>);
    static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
    static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
    // ...
    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        // one scale per output row (from A) and one per output column (from B)
        float a_scale = aq_m_1(m, 0);
        float b_scale = bq_1_n(0, n);

        for(std::size_t k = 0; k < K; ++k)
        {
            // v_a / v_b hold the converted A and B values for this k (declarations elided)
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                // ... unpack the packed int4 pair and select the element belonging to this k
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                // ... unpack the packed int4 pair and select the element belonging to this k
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }
            v_acc += v_a * v_b;
        }
        // apply the per-row and per-column scales once after the K reduction
        v_acc = v_acc * a_scale * b_scale;
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };
    // ... (f_mn is applied over all (m, n) of c_m_n)
}
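As the indexing aq_m_1(m, 0) and bq_1_n(0, n) suggests, the A scales are expected as an M x 1 tensor and the B scales as a 1 x N tensor. A minimal construction sketch, assuming host-side float scales (the shapes are inferred from the indexing above, the constructor form follows the (lengths, strides) pattern used later in this file):

    // per-row scales for A and per-column scales for B
    ck_tile::HostTensor<float> aq_m_1({std::size_t(M), std::size_t(1)}, {std::size_t(1), std::size_t(1)});
    ck_tile::HostTensor<float> bq_1_n({std::size_t(1), std::size_t(N)}, {std::size_t(N), std::size_t(1)});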
 
template <typename ADataType,
          // ... (AQDataType, BDataType, BQDataType elided)
          typename AccDataType,
          // ... (CDataType and element-op template parameters elided)
          >
CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k,
                                              const HostTensor<AQDataType>& aq_1_1,
                                              const HostTensor<BDataType>& b_k_n,
                                              const HostTensor<BQDataType>& bq_1_1,
                                              HostTensor<CDataType>& c_m_n,
                                              const AElementOp& a_element_op     = {},
                                              const BElementOp& b_element_op     = {},
                                              const ACCElementOp& acc_element_op = {})
{
    static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
    static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
    static_assert(std::is_same_v<AccDataType, float>);
    static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
    static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
    // ...
    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        // a single scale per operand tensor
        const AccDataType a_scale = ck_tile::type_convert<AccDataType>(aq_1_1(0, 0));
        const AccDataType b_scale = ck_tile::type_convert<AccDataType>(bq_1_1(0, 0));

        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            AccDataType v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            v_acc += v_a * v_b;
        }
        // apply both tensor-wide scales once after the K reduction
        v_acc = v_acc * a_scale * b_scale;
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };
    // ... (f_mn is applied over all (m, n) of c_m_n)
}
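With a single 1 x 1 scale per operand this reduces to C(m, n) = acc_element_op(a_scale * b_scale * sum_k A(m, k) * B(k, n)). For example, with a_scale = 0.5 and b_scale = 0.25 every accumulated dot product is simply multiplied by 0.125 before the output conversion.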
 
template <typename ADataType,
          // ... (BDataType elided)
          typename AccDataType,
          // ... (CDataType and element-op template parameters elided)
          >
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_k_n,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op     = {},
                                 const BElementOp& b_element_op     = {},
                                 const ACCElementOp& acc_element_op = {})
{
    // ...
    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;

        for(std::size_t k = 0; k < K; ++k)
        {
            // v_a / v_b hold the converted A and B values for this k (declarations elided)
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                // ... unpack the packed int4 pair and select the element belonging to this k
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                // ... unpack the packed int4 pair and select the element belonging to this k
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }
            v_acc += v_a * v_b;
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };
    // ... (f_mn is applied over all (m, n) of c_m_n)
}
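A minimal host-side usage sketch for the plain reference path, assuming fp16 inputs with fp32 accumulation and that the element-op template parameters are defaulted; the include path and fill values are illustrative, not taken from this listing:

    #include "ck_tile/host.hpp" // assumed umbrella header for the ck_tile host utilities

    void run_reference(std::size_t M, std::size_t N, std::size_t K)
    {
        using ck_tile::half_t;
        // row-major A (M x K), column-major B (K x N), row-major C (M x N)
        ck_tile::HostTensor<half_t> a_m_k({M, K}, {K, std::size_t(1)});
        ck_tile::HostTensor<half_t> b_k_n({K, N}, {std::size_t(1), K});
        ck_tile::HostTensor<half_t> c_m_n({M, N}, {N, std::size_t(1)});

        // fill the operands with some deterministic values
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t k = 0; k < K; ++k)
                a_m_k(m, k) = ck_tile::type_convert<half_t>(0.01f * float(k % 7));
        for(std::size_t k = 0; k < K; ++k)
            for(std::size_t n = 0; n < N; ++n)
                b_k_n(k, n) = ck_tile::type_convert<half_t>(0.02f * float(n % 5));

        // <ADataType, BDataType, AccDataType, CDataType>
        ck_tile::reference_gemm<half_t, half_t, float, half_t>(a_m_k, b_k_n, c_m_n);
    }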
 
template <typename AsDataType,
          // ... (BsDataType, DsDataType elided)
          typename AccDataType,
          // ... (CDataType, AElementOp, BElementOp elided)
          typename CDElementOp,
          typename ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>,
          typename BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k,
                            const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n,
                            const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                            HostTensor<ADataType>& a_m_k,
                            HostTensor<BDataType>& b_k_n,
                            HostTensor<CDataType>& c_m_n,
                            const AElementOp& a_element_op    = {},
                            const BElementOp& b_element_op    = {},
                            const CDElementOp& acc_element_op = {})
{
    // ...
    // bundle the input arrays into tuples of references so they can be expanded
    // into the element-op calls below
    const auto as_m_k_tuple =
        generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number<AsDataType::size()>{});
    const auto bs_k_n_tuple =
        generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number<BsDataType::size()>{});
    const auto ds_m_n_tuple =
        generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number<DsDataType::size()>{});

    // combine all A inputs (and all B inputs) element-wise into the working tensors
    auto a_elementwise_fn = [&](auto i, auto j) {
        ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple);
    };
    // ...
    auto b_elementwise_fn = [&](auto i, auto j) {
        ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple);
    };
    // ...

    auto f_mk_kn_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_m_k(m, k);
            BDataType v_b = b_k_n(k, n);
            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }
        // v_c: epilogue output (declaration elided); the epilogue op consumes the
        // accumulator plus one value from each D tensor
        ck_tile::apply(
            [&](auto&&... t) {
                acc_element_op(v_c,
                               ck_tile::type_convert<float>(v_acc),
                               ck_tile::type_convert<float>(t(m, n))...);
            },
            ds_m_n_tuple);
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
    };
    // ... (a_elementwise_fn, b_elementwise_fn and f_mk_kn_mn are applied over their index spaces)
}
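The element ops receive the destination element first, followed by one element from each tensor in the corresponding tuple. A hypothetical two-input A op that simply sums its sources (names and behavior are illustrative, not from this listing) could look like:

    struct SumTwoInputs
    {
        // y: destination element of a_m_k; x0, x1: elements of as_m_k[0] and as_m_k[1]
        template <typename Y, typename X0, typename X1>
        void operator()(Y& y, const X0& x0, const X1& x1) const
        {
            y = ck_tile::type_convert<Y>(ck_tile::type_convert<float>(x0) +
                                         ck_tile::type_convert<float>(x1));
        }
    };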
 
template <typename ADataType,
          // ... (BDataType elided)
          typename ScaleDataType,
          typename AccDataType,
          // ... (CDataType and element-op template parameters elided)
          >
CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
                                    const HostTensor<BDataType>& b_k_n,
                                    HostTensor<CDataType>& c_m_n,
                                    const HostTensor<ScaleDataType>& scale_a,
                                    const HostTensor<ScaleDataType>& scale_b,
                                    const AElementOp&   = {},
                                    const BElementOp&   = {},
                                    const ACCElementOp& = {})
{
    static_assert(std::is_same_v<AElementOp, ck_tile::identity>);
    static_assert(std::is_same_v<BElementOp, ck_tile::identity>);
    static_assert(std::is_same_v<ACCElementOp, ck_tile::identity>);
    // ...
    // number of K elements covered by one scale value
    const std::size_t ScaleBlockSize = K / scale_a.get_length(1);

    // dequantize A and B into dense AccDataType tensors first
    HostTensor<AccDataType> a_m_k_scaled({std::size_t(M), std::size_t(K)},
                                         {std::size_t(K), std::size_t(1)});
    HostTensor<AccDataType> b_k_n_scaled({std::size_t(K), std::size_t(N)},
                                         {std::size_t(1), std::size_t(K)});

    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t k = 0; k < K; ++k)
        {
            if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            {
                // each pk_fp4_t holds two fp4 values that share one scale
                auto a_f4x2  = a_m_k(m, k);
                auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
                const auto a_f4_lo =
                    ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<0>{}));
                const auto a_f4_hi =
                    ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<1>{}));
                a_m_k_scaled(m, k)     = a_f4_lo * a_scale;
                a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
            }
            // ... (non-pk_fp4 A path elided)
        }
    }

    for(std::size_t n = 0; n < N; n++)
    {
        for(std::size_t k = 0; k < K; k++)
        {
            if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            {
                auto b_f4x2  = b_k_n(k, n);
                auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
                const auto b_f4_lo =
                    ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<0>{}));
                const auto b_f4_hi =
                    ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<1>{}));
                b_k_n_scaled(k, n)     = b_f4_lo * b_scale;
                b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
            }
            else
            {
                b_k_n_scaled(k, n) =
                    ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
                    ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
            }
        }
    }
    // ...
    // run the plain reference GEMM on the dequantized tensors
    reference_gemm<AccDataType, AccDataType, AccDataType, CDataType>(
        a_m_k_scaled, b_k_n_scaled, c_m_n);
}
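As a worked example of the scale layout: with K = 256 and a scale_a tensor of shape M x 8, ScaleBlockSize = 256 / 8 = 32, so scale_a(m, k / 32) is shared by 32 consecutive K positions of row m; for pk_fp4_t inputs, both fp4 values of a packed pair fall inside the same scale block and are multiplied by the same scale.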
 
template <typename ADataType,
          // ... (BDataType, DsDataType elided)
          typename AccDataType,
          // ... (CDataType elided)
          typename ACCElementOp,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
                          const HostTensor<BDataType>& b_k_n,
                          const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                          HostTensor<CDataType>& c_m_n,
                          const ACCElementOp& acc_element_op = {})
{
    // ...
    auto f_mk_kn_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_m_k(m, k);
            BDataType v_b = b_k_n(k, n);
            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }
        // v_c: epilogue output (declaration elided); the epilogue op takes the accumulator
        // plus one element from each D tensor
        if constexpr(DsDataType::size() == 0)
        {
            acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
        }
        else if constexpr(DsDataType::size() == 1)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)));
        }
        else if constexpr(DsDataType::size() == 2)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)),
                           ck_tile::type_convert<float>(ds_m_n[1](m, n)));
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
    };
    // ... (f_mk_kn_mn is applied over all (m, n) of c_m_n)
}
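A hypothetical single-D epilogue op matching the DsDataType::size() == 1 call above, e.g. a bias add (purely illustrative, not part of this listing):

    struct AddBias
    {
        // c: output element; acc: GEMM accumulator; d: element of ds_m_n[0]
        template <typename C>
        void operator()(C& c, float acc, float d) const
        {
            c = ck_tile::type_convert<C>(acc + d);
        }
    };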
 
template <typename ADataType,
          // ... (BDataType elided)
          typename AccDataType,
          // ... (CDataType, LayoutA, LayoutB, LayoutC elided)
          >
__global__ void naive_gemm_kernel(ADataType* A,
                                  BDataType* B,
                                  CDataType* C,
                                  ck_tile::index_t M,
                                  ck_tile::index_t N,
                                  ck_tile::index_t K,
                                  ck_tile::index_t strideA,
                                  ck_tile::index_t strideB,
                                  ck_tile::index_t strideC)
{
    // one thread per C element: the (row, col) space is flattened over the 1D grid
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // ... (row and col are derived from idx; computation elided in this listing)

    if(row < M && col < N)
    {
        AccDataType acc = 0.0;
        for(int k = 0; k < K; ++k)
        {
            // linear offsets of A(row, k) and B(k, col) for the given layouts
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k
                              : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;

            // v_a / v_b hold the converted A and B values for this k (declarations elided)
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                // ... unpack the packed int4 pair and select the element belonging to this k
            }
            else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            {
                // ... unpack the packed fp4 pair and select the element belonging to this k
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                // ...
            }
            else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            {
                // ...
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
            }
            acc += v_a * v_b;
        }
        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index]  = ck_tile::type_convert<CDataType>(acc);
    }
}
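The listing does not show how row and col are derived from the flat thread index; a conventional mapping consistent with the row < M && col < N guard (an assumption for illustration, not taken from the source) would be:

    // one thread per output element of the M x N matrix, laid out row-major over the grid
    int row = idx / N; // output row owned by this thread
    int col = idx % N; // output column owned by this thread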
 
template <typename ADataType,
          // ... (BDataType elided)
          typename AccDataType,
          // ... (CDataType, LayoutA, LayoutB, LayoutC elided)
          >
__global__ void blockwise_gemm_kernel(ADataType* A,
                                      BDataType* B,
                                      CDataType* C,
                                      ck_tile::index_t M,
                                      ck_tile::index_t N,
                                      ck_tile::index_t K,
                                      ck_tile::index_t strideA,
                                      ck_tile::index_t strideB,
                                      ck_tile::index_t strideC,
                                      ck_tile::index_t scale_granularity_m,
                                      ck_tile::index_t scale_granularity_n,
                                      ck_tile::index_t scale_granularity_k,
                                      float* scale_A_ptr,
                                      float* scale_B_ptr)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // ... (row and col are derived from idx; computation elided in this listing)

    if(row < M && col < N)
    {
        // acc: final result; acc_temp: partial sum for the current scale block
        AccDataType acc = 0.0, acc_temp = 0.0;

        // number of scale entries along M (resp. N), i.e. the stride between K-groups
        // in the scale arrays
        index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
        index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
        // ... (scale_A / scale_B declarations elided)

        for(int k = 0; k < K; ++k)
        {
            // on every scale-block boundary, fold the previous partial sum into acc
            // and load the scales for the new block
            if(k % scale_granularity_k == 0)
            {
                acc += acc_temp * scale_A * scale_B;
                // ... (acc_temp is reset here)
                scale_A = scale_A_ptr[(row / scale_granularity_m) +
                                      (k / scale_granularity_k) * scale_A_stride];
                scale_B = scale_B_ptr[(col / scale_granularity_n) +
                                      (k / scale_granularity_k) * scale_B_stride];
            }

            // linear offsets of A(row, k) and B(k, col) for the given layouts
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k
                              : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;

            // v_a / v_b hold the converted A and B values for this k (declarations elided)
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                // ... unpack the packed int4 pair and select the element belonging to this k
            }
            else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            {
                // ... unpack the packed fp4 pair and select the element belonging to this k
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                // ...
            }
            else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            {
                // ...
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
            }

            acc_temp += v_a * v_b;
        }
        // fold in the last (possibly partial) scale block
        acc += acc_temp * scale_A * scale_B;

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index]  = ck_tile::type_convert<CDataType>(acc);
    }
}
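As a concrete reading of the scale indexing: with M = 128 and scale_granularity_m = 32, scale_A_stride = 4; the thread owning row 70 at k = 64 with scale_granularity_k = 64 reads scale_A_ptr[70 / 32 + (64 / 64) * 4] = scale_A_ptr[2 + 4] = scale_A_ptr[6], i.e. the scale of M-group 2 in K-group 1.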
 
template <typename ADataType,
          // ... (BDataType elided)
          typename AccDataType,
          // ... (CDataType, LayoutA, LayoutB, LayoutC elided)
          >
void reference_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr,
                        index_t M, index_t N, index_t K,
                        index_t stride_a, index_t stride_b, index_t stride_c)
{
    // launch one thread per output element of C
    int totalElements      = M * N;
    int numThreadsPerBlock = 256;
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(
            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
    // ...
}
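For a 4096 x 4096 output this launches (4096 * 4096 + 255) / 256 = 65,536 blocks of 256 threads, one thread per C element; any synchronization or error checking after the launch is outside this listing.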
 
template <typename ADataType,
          // ... (BDataType elided)
          typename AccDataType,
          // ... (CDataType, LayoutA, LayoutB, LayoutC elided)
          >
void reference_blockwise_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr,
                                  index_t M, index_t N, index_t K,
                                  index_t stride_a, index_t stride_b, index_t stride_c,
                                  index_t scale_granularity_m, index_t scale_granularity_n,
                                  index_t scale_granularity_k,
                                  float* scale_A_ptr, float* scale_B_ptr)
{
    // launch one thread per output element of C
    int totalElements      = M * N;
    int numThreadsPerBlock = 256;
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    blockwise_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(a_ptr, b_ptr, c_ptr, M, N, K,
                                            stride_a, stride_b, stride_c,
                                            scale_granularity_m, scale_granularity_n,
                                            scale_granularity_k, scale_A_ptr, scale_B_ptr);
}
 
template <typename ADataType,
          // ... (BDataType elided)
          typename AccDataType,
          // ... (CDataType, LayoutA, LayoutB, LayoutC elided)
          >
void reference_batched_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr,
                                index_t M, index_t N, index_t K,
                                index_t stride_a, index_t stride_b, index_t stride_c,
                                index_t batch_stride_A, index_t batch_stride_B,
                                index_t batch_stride_C, index_t batch_count)
{
    // launch one thread per output element of C; the same grid is reused for every batch
    int totalElements      = M * N;
    int numThreadsPerBlock = 256;
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        // offset each operand to the current batch
        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }
}
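Since the batch offsets are applied to typed pointers, the batch strides are counted in elements. For dense, contiguously packed batches a caller would typically pass batch_stride_A = M * K, batch_stride_B = K * N and batch_stride_C = M * N (an assumption about the caller, not something this listing fixes), so batch b starts at a_ptr + b * M * K and so on.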
 
Referenced ck_tile symbols (defined in other headers):

CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)   (host_tensor.hpp:329)
std::size_t HostTensor::get_length(std::size_t dim) const   (host_tensor.hpp:388)
constexpr decltype(auto) apply(F&& f, Tuple&& t)   (tuple.hpp:526)
constexpr CK_TILE_HOST_DEVICE auto generate_tie(F&& f, number<N>)   (tuple.hpp:435)
constant<v> number   (integral_constant.hpp:37)
int32_t index_t   (integer.hpp:9)
fp32x2_t, a packed pair of floats   (pk_fp4.hpp:22)
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)   (pk_int4.hpp:105)
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t& x)   (pk_int4.hpp:120)
constexpr CK_TILE_HOST_DEVICE fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale)   (pk_fp4.hpp:350)
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t)   (float8.hpp:751)
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t)   (float8.hpp:764)