// reference_gemm_quant (reference_gemm.hpp:24): group-wise quantized GEMM
// reference. One scale per QuantGroupSize block is applied to either A or B,
// selected by the `aquant` flag.
template <typename ADataType,
          /* ... QDataType, BDataType, AccDataType, CDataType ... */
          typename QuantGroupSize,
          /* ... aquant, element-op types ... */>
CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
                                       const HostTensor<QDataType>& q,
                                       const HostTensor<BDataType>& b_k_n,
                                       HostTensor<CDataType>& c_m_n,
                                       const AElementOp& a_element_op     = {},
                                       const BElementOp& b_element_op     = {},
                                       const ACCElementOp& acc_element_op = {})
{
    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        constexpr std::size_t kGroupK = QuantGroupSize::kK;

        auto load_a = [&](std::size_t k) -> AccDataType {
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                // ... unpack to an fp32x2_t fp32_val ...
                return (k & 1) ? fp32_val.hi : fp32_val.lo;
            }
            // ...
            return ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
        };

        auto load_b = [&](std::size_t k) -> AccDataType {
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                // ... unpack to an fp32x2_t fp32_val ...
                return (k & 1) ? fp32_val.hi : fp32_val.lo;
            }
            else if constexpr(std::is_same_v<BDataType, fp8_t>)
            {
                // ... raw fp8 conversion ...
            }
            return ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
        };

        auto load_scale = [&](std::size_t k_group) {
            // Which dimension is grouped depends on whether A or B is quantized.
            const ck_tile::index_t outer_dim = aquant ? (m / QuantGroupSize::kM) : k_group;
            const ck_tile::index_t inner_dim = aquant ? k_group : (n / QuantGroupSize::kN);
            if constexpr(std::is_same_v<QDataType, float>)
            {
                return q(outer_dim, inner_dim);
            }
            else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
            {
                // ... convert the fp8 scale to float ...
            }
        };

        for(std::size_t k_begin = 0; k_begin < K; k_begin += kGroupK)
        {
            const std::size_t k_end = std::min<std::size_t>(k_begin + kGroupK, K);
            AccDataType v_block_acc = 0;
            for(std::size_t k = k_begin; k < k_end; ++k)
            {
                const AccDataType v_a = load_a(k);
                const AccDataType v_b = load_b(k);
                v_block_acc += v_a * v_b;
            }
            const std::size_t k_group = k_begin / kGroupK;
            const float scale         = load_scale(k_group);
            v_acc += v_block_acc * scale;
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };
    // ... run f_mn over the M x N output (make_ParallelTensorFunctor) ...
    std::cout << std::endl;
}
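The group-wise scheme above reduces, per output element, to a scaled sum of per-group partial dot products. A minimal standalone sketch in plain C++ (hypothetical names, no ck_tile types; the B-quant case with the scale already converted to float):

#include <algorithm>
#include <cstddef>
#include <vector>

// acc = sum over groups g of scales[g] * sum_{k in g} a[k] * b[k]
float group_quant_dot(const std::vector<float>& a,      // converted A row
                      const std::vector<float>& b,      // B column, raw quantized values
                      const std::vector<float>& scales, // one scale per K-group
                      std::size_t group_k)
{
    float acc = 0.f;
    for(std::size_t k_begin = 0; k_begin < a.size(); k_begin += group_k)
    {
        const std::size_t k_end = std::min(k_begin + group_k, a.size());
        float block_acc = 0.f;
        for(std::size_t k = k_begin; k < k_end; ++k)
            block_acc += a[k] * b[k];
        acc += block_acc * scales[k_begin / group_k];
    }
    return acc;
}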
// reference_gemm_abquant (reference_gemm.hpp:131): group-wise quantization on
// both operands; each K-group partial sum is scaled by the product of the A
// and B group scales.
template <typename ADataType,
          /* ... */ typename AccDataType, /* ... */
          typename AQuantGroupSize,
          typename BQuantGroupSize,
          /* ... */>
CK_TILE_HOST void reference_gemm_abquant(const HostTensor<ADataType>& a_m_k,
                                         const HostTensor<AQDataType>& a_q,
                                         const HostTensor<BDataType>& b_k_n,
                                         const HostTensor<BQDataType>& b_q,
                                         HostTensor<CDataType>& c_m_n,
                                         const AElementOp& a_element_op     = {},
                                         const BElementOp& b_element_op     = {},
                                         const ACCElementOp& acc_element_op = {})
{
    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        constexpr std::size_t kGroupK = BQuantGroupSize::kK;

        auto load_a = [&](std::size_t k) -> AccDataType {
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                // ... unpack to an fp32x2_t fp32_val ...
                return (k & 1) ? fp32_val.hi : fp32_val.lo;
            }
            // ...
            return ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
        };

        auto load_b = [&](std::size_t k) -> AccDataType {
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                // ... unpack to an fp32x2_t fp32_val ...
                return (k & 1) ? fp32_val.hi : fp32_val.lo;
            }
            else if constexpr(std::is_same_v<BDataType, fp8_t>)
            {
                // ...
            }
            return ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
        };

        auto load_scale_a = [&](std::size_t k_group) {
            const ck_tile::index_t outer_dim = m / AQuantGroupSize::kM;
            const ck_tile::index_t inner_dim = k_group;
            if constexpr(std::is_same_v<AQDataType, float>)
                return a_q(outer_dim, inner_dim);
            else if constexpr(std::is_same_v<AQDataType, ck_tile::fp8_t>)
            {
                // ...
            }
        };

        auto load_scale_b = [&](std::size_t k_group) {
            const ck_tile::index_t outer_dim = k_group;
            const ck_tile::index_t inner_dim = n / BQuantGroupSize::kN;
            if constexpr(std::is_same_v<BQDataType, float>)
                return b_q(outer_dim, inner_dim);
            else if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
            {
                // ...
            }
        };

        for(std::size_t k_begin = 0; k_begin < K; k_begin += kGroupK)
        {
            const std::size_t k_end = std::min<std::size_t>(k_begin + kGroupK, K);
            AccDataType v_block_acc = 0;
            for(std::size_t k = k_begin; k < k_end; ++k)
            {
                const AccDataType v_a = load_a(k);
                const AccDataType v_b = load_b(k);
                v_block_acc += v_a * v_b;
            }
            const std::size_t k_group = k_begin / kGroupK;
            const float scale_a       = load_scale_a(k_group);
            const float scale_b       = load_scale_b(k_group);
            v_acc += v_block_acc * scale_a * scale_b;
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };
    // ... run f_mn over the M x N output ...
}
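The only structural difference from the single-operand version is the scale lookup: the A scales form an (M / kM) x (K / kK) grid and the B scales a (K / kK) x (N / kN) grid, both indexed by the same k_group. Hypothetical helpers mirroring the outer_dim/inner_dim arithmetic above:

#include <cstddef>

struct ScaleCoord
{
    std::size_t outer, inner;
};

// a_q is read as a_q(m / kM, k_group); b_q as b_q(k_group, n / kN).
ScaleCoord a_scale_coord(std::size_t m, std::size_t k, std::size_t kM, std::size_t kK)
{
    return {m / kM, k / kK};
}

ScaleCoord b_scale_coord(std::size_t n, std::size_t k, std::size_t kN, std::size_t kK)
{
    return {k / kK, n / kN};
}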
// reference_gemm_rowcol_quant (reference_gemm.hpp:254): one scale per row of A
// (aq_m_1) and one per column of B (bq_1_n), applied once after the K reduction.
template <typename ADataType, /* ... */ typename AccDataType, /* ... */>
CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k,
                                              const HostTensor<AQDataType>& aq_m_1,
                                              const HostTensor<BDataType>& b_k_n,
                                              const HostTensor<BQDataType>& bq_1_n,
                                              HostTensor<CDataType>& c_m_n,
                                              const AElementOp& a_element_op     = {},
                                              const BElementOp& b_element_op     = {},
                                              const ACCElementOp& acc_element_op = {})
{
    static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
    static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
    static_assert(std::is_same_v<AccDataType, float>);
    static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
    static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        float a_scale = aq_m_1(m, 0);
        float b_scale = bq_1_n(0, n);
        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a, v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                // ...
            }
            else
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                // ...
            }
            else
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            v_acc += v_a * v_b;
        }
        v_acc = v_acc * a_scale * b_scale;
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };
    // ... run f_mn over the M x N output ...
}
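Because the row and column scales do not depend on k, scaling once after the reduction is exactly equivalent to scaling every product; in matrix form C = diag(aq) (A B) diag(bq). A scalar sketch of one output element:

#include <cstddef>
#include <vector>

// One entry of a row/column-quantized GEMM: scales applied after the K loop.
float rowcol_quant_entry(const std::vector<float>& a_row,
                         const std::vector<float>& b_col,
                         float a_scale, float b_scale)
{
    float acc = 0.f;
    for(std::size_t k = 0; k < a_row.size(); ++k)
        acc += a_row[k] * b_col[k];
    return acc * a_scale * b_scale;
}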
// reference_gemm_tensor_quant (reference_gemm.hpp:335): a single per-tensor
// scale for A (aq_1_1) and for B (bq_1_1), applied once after the K reduction.
template <typename ADataType, /* ... */ typename AccDataType, /* ... */>
CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k,
                                              const HostTensor<AQDataType>& aq_1_1,
                                              const HostTensor<BDataType>& b_k_n,
                                              const HostTensor<BQDataType>& bq_1_1,
                                              HostTensor<CDataType>& c_m_n,
                                              const AElementOp& a_element_op     = {},
                                              const BElementOp& b_element_op     = {},
                                              const ACCElementOp& acc_element_op = {})
{
    static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
    static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
    static_assert(std::is_same_v<AccDataType, float>);
    static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
    static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        const AccDataType a_scale = ck_tile::type_convert<AccDataType>(aq_1_1(0, 0));
        const AccDataType b_scale = ck_tile::type_convert<AccDataType>(bq_1_1(0, 0));
        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            AccDataType v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            v_acc += v_a * v_b;
        }
        v_acc = v_acc * a_scale * b_scale;
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };
    // ... run f_mn over the M x N output ...
}
// reference_mxfp4gemm_quant (reference_gemm.hpp:387): B holds packed fp4 pairs
// along K with e8m0 block scales in q; the stored scale is a biased exponent
// (bias 127).
template <typename ADataType, /* ... */ typename AccDataType, /* ... */
          typename QuantGroupSize, /* ... */>
CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor<ADataType>& a_m_k,
                                            const HostTensor<QDataType>& q,
                                            const HostTensor<BDataType>& b_k_n,
                                            HostTensor<CDataType>& c_m_n,
                                            const AElementOp& a_element_op     = {},
                                            const BElementOp& b_element_op     = {},
                                            const ACCElementOp& acc_element_op = {})
{
    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc  = 0;
        AccDataType pasual = 0; // partial dot product of one packed pair
        for(std::size_t k = 0; k < (K / 2); k++)
        {
            using ComputeType = float;
            auto b_scale = type_convert<int32_t>(q((2 * k) / QuantGroupSize::kK, n)) - 127;
            ComputeType v_a_0, v_a_1;
            ComputeType v_b_0, v_b_1;

            v_a_0 = ck_tile::type_convert<ComputeType>(a_element_op(a_m_k(m, 2 * k)));
            v_a_1 = ck_tile::type_convert<ComputeType>(a_element_op(a_m_k(m, 2 * k + 1)));

            if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
            {
                auto b_pack      = type_convert<pk_fp4_t>(b_element_op(b_k_n(k, n)));
                auto b_scale_fp4 = type_convert<float>(std::pow(2.0f, b_scale));
                auto b_f4_lo     = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
                auto b_f4_hi     = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
                v_b_0 = type_convert<ComputeType>(b_f4_lo) * b_scale_fp4;
                v_b_1 = type_convert<ComputeType>(b_f4_hi) * b_scale_fp4;
            }
            // ...
            pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1;
            v_acc += pasual;
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };
    // ... run f_mn over the M x N output ...
    std::cout << std::endl;
}
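The stored q value is an e8m0 scale, i.e. a biased exponent, which is why the code subtracts 127 and then computes std::pow(2.0f, b_scale). A standalone sketch of decoding one packed fp4 (e2m1) byte with such a scale; treating the low nibble as the even-k element is an assumption of this sketch:

#include <cmath>
#include <cstdint>
#include <utility>

// e2m1 magnitudes indexed by the 3 value bits (2 exponent, 1 mantissa).
static constexpr float kFp4Mag[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

float fp4_decode(std::uint8_t nibble)
{
    const float mag = kFp4Mag[nibble & 0x7];
    return (nibble & 0x8) ? -mag : mag;
}

// Decode one packed pair with its e8m0 block scale (biased exponent, bias 127).
std::pair<float, float> pk_fp4_dequant(std::uint8_t packed, std::uint8_t e8m0)
{
    const float scale = std::exp2f(static_cast<int>(e8m0) - 127);
    return {fp4_decode(packed & 0xF) * scale, fp4_decode(packed >> 4) * scale};
}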
// reference_gemm (reference_gemm.hpp:441): plain (unquantized) host reference.
template <typename ADataType, /* ... */ typename AccDataType, /* ... */>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_k_n,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op     = {},
                                 const BElementOp& b_element_op     = {},
                                 const ACCElementOp& acc_element_op = {})
{
    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a, v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                // ...
            }
            else
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                // ...
            }
            else
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            v_acc += v_a * v_b;
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };
    // ... run f_mn over the M x N output ...
}
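A host-side usage sketch. The explicit template arguments follow the <A, B, Acc, C> order seen in the reference_mx_gemm body later in this file, as does the {lengths}, {strides} HostTensor constructor; the umbrella include path is an assumption:

#include <cstddef>
#include "ck_tile/host.hpp" // assumed umbrella header for the host utilities

void run_reference(std::size_t M, std::size_t N, std::size_t K)
{
    // Row-major A (M x K), column-major B (K x N), row-major C (M x N).
    ck_tile::HostTensor<float> a_m_k({M, K}, {K, std::size_t{1}});
    ck_tile::HostTensor<float> b_k_n({K, N}, {std::size_t{1}, K});
    ck_tile::HostTensor<float> c_m_n({M, N}, {N, std::size_t{1}});
    // ... fill a_m_k and b_k_n ...
    ck_tile::reference_gemm<float, float, float, float>(a_m_k, b_k_n, c_m_n);
}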
// reference_gemm_multiple_abd (reference_gemm.hpp:506): fuses several A, B and
// D tensors. The element ops first combine the extra A/B tensors into
// a_m_k / b_k_n, a plain GEMM runs, and the CD element op folds the D tensors
// into the output.
template <typename AsDataType,
          /* ... */ typename AccDataType, /* ... */
          typename CDElementOp,
          typename ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>,
          typename BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k,
                            const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n,
                            const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                            HostTensor<ADataType>& a_m_k,
                            HostTensor<BDataType>& b_k_n,
                            HostTensor<CDataType>& c_m_n,
                            const AElementOp& a_element_op    = {},
                            const BElementOp& b_element_op    = {},
                            const CDElementOp& acc_element_op = {})
{
    auto as_m_k_tuple = generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; },
                                     number<AsDataType::size()>{});
    auto bs_k_n_tuple = generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; },
                                     number<BsDataType::size()>{});
    auto ds_m_n_tuple = generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; },
                                     number<DsDataType::size()>{});

    auto a_elementwise_fn = [&](auto i, auto j) {
        ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple);
    };
    auto b_elementwise_fn = [&](auto i, auto j) {
        ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple);
    };
    // ... run a_elementwise_fn / b_elementwise_fn over a_m_k and b_k_n ...

    auto f_mk_kn_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_m_k(m, k);
            BDataType v_b = b_k_n(k, n);
            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }
        float v_c = 0;
        ck_tile::apply(
            [&](auto&&... t) {
                acc_element_op(v_c,
                               ck_tile::type_convert<float>(v_acc),
                               ck_tile::type_convert<float>(t(m, n))...);
            },
            ds_m_n_tuple);
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
    };
    // ... run f_mk_kn_mn over the M x N output ...
}
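generate_tie plus apply is a tuple-of-references idiom: build a tuple that references each extra tensor, then splat it into a variadic element op. The same idiom with only the standard library (hypothetical names):

#include <array>
#include <cstddef>
#include <tuple>
#include <utility>

// What generate_tie does: a tuple of references to the array elements.
template <typename T, std::size_t N, std::size_t... Is>
auto tie_array_impl(std::array<T, N>& arr, std::index_sequence<Is...>)
{
    return std::tie(arr[Is]...);
}

template <typename T, std::size_t N>
auto tie_array(std::array<T, N>& arr)
{
    return tie_array_impl(arr, std::make_index_sequence<N>{});
}

int main()
{
    std::array<float, 3> extras{1.f, 2.f, 3.f};
    float out = 0.f;
    // What ck_tile::apply does: splat the tuple into a variadic call.
    auto op = [](float& dst, auto... srcs) { dst = (srcs + ...); };
    std::apply([&](auto&... t) { op(out, t...); }, tie_array(extras));
    return out == 6.f ? 0 : 1;
}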
// reference_mx_gemm (reference_gemm.hpp:577): microscaling GEMM. A and B are
// first dequantized into full-precision copies using their block scales, then
// the plain reference runs on the copies.
template <typename ADataType,
          /* ... */
          typename ScaleDataType,
          typename AccDataType,
          /* ... */>
CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
                                    const HostTensor<BDataType>& b_k_n,
                                    HostTensor<CDataType>& c_m_n,
                                    const HostTensor<ScaleDataType>& scale_a,
                                    const HostTensor<ScaleDataType>& scale_b,
                                    const AElementOp&   = {},
                                    const BElementOp&   = {},
                                    const ACCElementOp& = {})
{
    static_assert(std::is_same_v<AElementOp, ck_tile::identity>);
    static_assert(std::is_same_v<BElementOp, ck_tile::identity>);
    static_assert(std::is_same_v<ACCElementOp, ck_tile::identity>);

    // One scale covers K / (number of scale columns) consecutive K elements.
    const std::size_t ScaleBlockSize = K / scale_a.get_length(1);

    HostTensor<AccDataType> a_m_k_scaled({std::size_t(M), std::size_t(K)},
                                         {std::size_t(K), std::size_t(1)});
    HostTensor<AccDataType> b_k_n_scaled({std::size_t(K), std::size_t(N)},
                                         {std::size_t(1), std::size_t(K)});

    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t k = 0; k < K; ++k)
        {
            if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            {
                // Each pk_fp4_t packs two consecutive K elements.
                auto a_f4x2  = a_m_k(m, k);
                auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
                auto a_f4_lo =
                    ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<0>{}));
                auto a_f4_hi =
                    ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<1>{}));
                a_m_k_scaled(m, k)     = a_f4_lo * a_scale;
                a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
                // ... (k advances past the packed pair) ...
            }
            else
            {
                a_m_k_scaled(m, k) =
                    ck_tile::type_convert<AccDataType>(a_m_k(m, k)) *
                    ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
            }
        }
    }
    for(std::size_t n = 0; n < N; n++)
    {
        for(std::size_t k = 0; k < K; k++)
        {
            if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            {
                auto b_f4x2  = b_k_n(k, n);
                auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
                auto b_f4_lo =
                    ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<0>{}));
                auto b_f4_hi =
                    ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<1>{}));
                b_k_n_scaled(k, n)     = b_f4_lo * b_scale;
                b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
                // ... (k advances past the packed pair) ...
            }
            else
            {
                b_k_n_scaled(k, n) =
                    ck_tile::type_convert<AccDataType>(b_k_n(k, n)) *
                    ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
            }
        }
    }

    reference_gemm<AccDataType, AccDataType, AccDataType, CDataType>(
        a_m_k_scaled, b_k_n_scaled, c_m_n);
}
// reference_gemm_multiple_d (reference_gemm.hpp:670): GEMM with up to two extra
// D tensors fused into the epilogue through acc_element_op.
template <typename ADataType,
          /* ... */ typename AccDataType, /* ... */
          typename ACCElementOp,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
                          const HostTensor<BDataType>& b_k_n,
                          const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                          HostTensor<CDataType>& c_m_n,
                          const ACCElementOp& acc_element_op = {})
{
    auto f_mk_kn_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_m_k(m, k);
            BDataType v_b = b_k_n(k, n);
            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }
        float v_c = 0;
        if constexpr(DsDataType::size() == 0)
        {
            acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
        }
        else if constexpr(DsDataType::size() == 1)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)));
        }
        else if constexpr(DsDataType::size() == 2)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)),
                           ck_tile::type_convert<float>(ds_m_n[1](m, n)));
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
    };
    // ... run f_mk_kn_mn over the M x N output ...
}
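The epilogue op receives the destination by reference, then the accumulator, then one float per D tensor, so a bias add for the size() == 1 branch can be a simple functor (hypothetical name):

// Compatible with acc_element_op(v_c, acc, d0) in the size() == 1 branch above.
struct AddBias
{
    template <typename C, typename Acc, typename D>
    void operator()(C& c, const Acc& acc, const D& d0) const
    {
        c = acc + d0;
    }
};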
// naive_gemm_kernel (reference_gemm.hpp:721): one GPU thread per C element.
template <typename ADataType, /* ... */ typename AccDataType, /* ... */>
__global__ void naive_gemm_kernel(ADataType* A, BDataType* B, CDataType* C,
                                  ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                                  ck_tile::index_t strideA, ck_tile::index_t strideB,
                                  ck_tile::index_t strideC)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int row = idx / N, col = idx % N;
    if(row < M && col < N)
    {
        AccDataType acc = 0.0;
        for(int k = 0; k < K; ++k)
        {
            AccDataType v_a, v_b;
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k
                              : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            { /* ... unpack the int4 pair ... */ }
            else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            { /* ... unpack the fp4 pair ... */ }
            else
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            { /* ... */ }
            else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            { /* ... */ }
            else
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
            acc += v_a * v_b;
        }
        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}
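The layout selection boils down to one flattening rule per operand; reading the ternaries off, row-major A implies strideA = K, column-major B implies strideB = K, and row-major C implies strideC = N. A host-side mirror of that rule (hypothetical helper):

// Row-major flattens as row * stride + col, column-major as col * stride + row.
inline int flat_index(bool row_major, int row, int col, int stride)
{
    return row_major ? row * stride + col : col * stride + row;
}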
// blockwise_gemm_kernel (reference_gemm.hpp:809): like naive_gemm_kernel, but
// per-block scales are refreshed every scale_granularity_k steps along K.
template <typename ADataType, /* ... */ typename AccDataType, /* ... */>
__global__ void blockwise_gemm_kernel(ADataType* A, BDataType* B, CDataType* C,
                                      ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                                      ck_tile::index_t strideA, ck_tile::index_t strideB,
                                      ck_tile::index_t strideC,
                                      ck_tile::index_t scale_granularity_m,
                                      ck_tile::index_t scale_granularity_n,
                                      ck_tile::index_t scale_granularity_k,
                                      float* scale_A_ptr, float* scale_B_ptr)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int row = idx / N, col = idx % N;
    if(row < M && col < N)
    {
        AccDataType acc = 0.0, acc_temp = 0.0;
        float scale_A = 1.0f, scale_B = 1.0f;
        index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
        index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
        for(int k = 0; k < K; ++k)
        {
            if(k % scale_granularity_k == 0)
            {
                // Commit the previous K block with its scales, then load new ones.
                acc += acc_temp * scale_A * scale_B;
                acc_temp = 0.0;
                scale_A = scale_A_ptr[(row / scale_granularity_m) +
                                      (k / scale_granularity_k) * scale_A_stride];
                scale_B = scale_B_ptr[(col / scale_granularity_n) +
                                      (k / scale_granularity_k) * scale_B_stride];
            }
            AccDataType v_a, v_b;
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k
                              : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            { /* ... unpack the int4 pair ... */ }
            else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            { /* ... unpack the fp4 pair ... */ }
            else
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            { /* ... */ }
            else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            { /* ... */ }
            else
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
            acc_temp += v_a * v_b;
        }
        acc += acc_temp * scale_A * scale_B; // commit the final K block
        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}
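A scalar model of the flush-at-boundary accumulation (same assumptions as the kernel: scale vectors indexed by k / granularity_k, scales constant within a block):

#include <cstddef>
#include <vector>

float blockwise_dot(const std::vector<float>& a, const std::vector<float>& b,
                    const std::vector<float>& sa, const std::vector<float>& sb,
                    std::size_t granularity_k)
{
    float acc = 0.f, acc_temp = 0.f, scale_a = 1.f, scale_b = 1.f;
    for(std::size_t k = 0; k < a.size(); ++k)
    {
        if(k % granularity_k == 0)
        {
            acc += acc_temp * scale_a * scale_b; // commit the previous block
            acc_temp = 0.f;
            scale_a  = sa[k / granularity_k];
            scale_b  = sb[k / granularity_k];
        }
        acc_temp += a[k] * b[k];
    }
    return acc + acc_temp * scale_a * scale_b; // commit the final block
}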
// reference_gemm_gpu (reference_gemm.hpp:924): host-side launcher for
// naive_gemm_kernel, one thread per element of the M x N output.
template <typename ADataType, /* ... */ typename AccDataType, /* ... */>
void reference_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr,
                        index_t M, index_t N, index_t K,
                        index_t stride_a, index_t stride_b, index_t stride_c)
{
    int totalElements      = M * N;
    int numThreadsPerBlock = 256;
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(
            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
}
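A HIP usage sketch for the launcher. The template argument order is assumed to mirror the naive_gemm_kernel instantiation above, and the strides follow the kernel's index arithmetic (row-major A and C, column-major B); the include for the launcher itself is left as an assumption:

#include <hip/hip_runtime.h>
#include <vector>
// plus the header providing ck_tile::reference_gemm_gpu (path assumed)

void run_gpu_reference(int M, int N, int K)
{
    std::vector<float> ha(M * K, 1.f), hb(K * N, 1.f), hc(M * N, 0.f);
    float *da, *db, *dc;
    hipMalloc(&da, ha.size() * sizeof(float));
    hipMalloc(&db, hb.size() * sizeof(float));
    hipMalloc(&dc, hc.size() * sizeof(float));
    hipMemcpy(da, ha.data(), ha.size() * sizeof(float), hipMemcpyHostToDevice);
    hipMemcpy(db, hb.data(), hb.size() * sizeof(float), hipMemcpyHostToDevice);

    using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
    using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
    ck_tile::reference_gemm_gpu<float, float, float, float, RowMajor, ColMajor, RowMajor>(
        da, db, dc, M, N, K, /*stride_a=*/K, /*stride_b=*/K, /*stride_c=*/N);

    hipMemcpy(hc.data(), dc, hc.size() * sizeof(float), hipMemcpyDeviceToHost);
    hipFree(da);
    hipFree(db);
    hipFree(dc);
}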
// reference_blockwise_gemm_gpu (reference_gemm.hpp:952): host-side launcher
// for blockwise_gemm_kernel.
template <typename ADataType, /* ... */ typename AccDataType, /* ... */>
void reference_blockwise_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr,
                                  index_t M, index_t N, index_t K, index_t stride_a,
                                  index_t stride_b, index_t stride_c,
                                  index_t scale_granularity_m, index_t scale_granularity_n,
                                  index_t scale_granularity_k,
                                  float* scale_A_ptr, float* scale_B_ptr)
{
    int totalElements      = M * N;
    int numThreadsPerBlock = 256;
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
    blockwise_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b,
                                            stride_c, scale_granularity_m, scale_granularity_n,
                                            scale_granularity_k, scale_A_ptr, scale_B_ptr);
}
// reference_batched_gemm_gpu (reference_gemm.hpp:997): launches
// naive_gemm_kernel once per batch at the batch-strided pointer offsets.
template <typename ADataType, /* ... */ typename AccDataType, /* ... */>
void reference_batched_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr,
                                index_t M, index_t N, index_t K, index_t stride_a,
                                index_t stride_b, index_t stride_c, index_t batch_stride_A,
                                index_t batch_stride_B, index_t batch_stride_C,
                                index_t batch_count)
{
    int totalElements      = M * N;
    int numThreadsPerBlock = 256;
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }
}
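Per batch, the launcher only rebases the three pointers; element (b, m, k) of a row-major batched A lives at a_ptr + b * batch_stride_A + m * stride_a + k. A hypothetical host-side helper stating that address rule:

#include <cstddef>

template <typename T>
T* batched_element(T* base, std::size_t batch, std::size_t batch_stride,
                   std::size_t row, std::size_t col, std::size_t lead_stride)
{
    return base + batch * batch_stride + row * lead_stride + col;
}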
Cross-references (Doxygen definitions):

CK_TILE_HOST (config.hpp:44)
__host__ T pow(T x, T gamma) (math_v2.hpp:427)
namespace ck_tile (cluster_descriptor.hpp:13)
void reference_batched_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t batch_stride_A, index_t batch_stride_B, index_t batch_stride_C, index_t batch_count) (reference_gemm.hpp:997)
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs) (host_tensor.hpp:329)
constexpr decltype(auto) apply(F&& f, Tuple&& t) (tuple.hpp:526)
__global__ void naive_gemm_kernel(ADataType* A, BDataType* B, CDataType* C, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC) (reference_gemm.hpp:721)
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) (pk_int4.hpp:105)
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t) (float8.hpp:751)
CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k, const HostTensor<QDataType>& q, const HostTensor<BDataType>& b_k_n, HostTensor<CDataType>& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) (reference_gemm.hpp:24)
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t) (float8.hpp:764)
CK_TILE_HOST void reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k, const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n, const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n, HostTensor<ADataType>& a_m_k, HostTensor<BDataType>& b_k_n, HostTensor<CDataType>& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const CDElementOp& acc_element_op = {}) (reference_gemm.hpp:506)
fp32x2_t (bfloat16.hpp:434)
void reference_blockwise_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float* scale_A_ptr, float* scale_B_ptr) (reference_gemm.hpp:952)
int32_t index_t (integer.hpp:9)
CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k, const HostTensor<AQDataType>& aq_m_1, const HostTensor<BDataType>& b_k_n, const HostTensor<BQDataType>& bq_1_n, HostTensor<CDataType>& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) (reference_gemm.hpp:254)
constexpr CK_TILE_HOST_DEVICE auto generate_tie(F&& f, number<N>) (tuple.hpp:435)
constant<v> number (integral_constant.hpp:37)
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t& x) (pk_int4.hpp:120)
CK_TILE_HOST void reference_gemm_abquant(const HostTensor<ADataType>& a_m_k, const HostTensor<AQDataType>& a_q, const HostTensor<BDataType>& b_k_n, const HostTensor<BQDataType>& b_q, HostTensor<CDataType>& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) (reference_gemm.hpp:131)
__global__ void blockwise_gemm_kernel(ADataType* A, BDataType* B, CDataType* C, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC, ck_tile::index_t scale_granularity_m, ck_tile::index_t scale_granularity_n, ck_tile::index_t scale_granularity_k, float* scale_A_ptr, float* scale_B_ptr) (reference_gemm.hpp:809)
constexpr CK_TILE_HOST_DEVICE fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale) (pk_fp4.hpp:350)
void reference_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c) (reference_gemm.hpp:924)
CK_TILE_HOST void reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k, const HostTensor<BDataType>& b_k_n, const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n, HostTensor<CDataType>& c_m_n, const ACCElementOp& acc_element_op = {}) (reference_gemm.hpp:670)
CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor<ADataType>& a_m_k, const HostTensor<QDataType>& q, const HostTensor<BDataType>& b_k_n, HostTensor<CDataType>& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) (reference_gemm.hpp:387)
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k, const HostTensor<BDataType>& b_k_n, HostTensor<CDataType>& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) (reference_gemm.hpp:441)
CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k, const HostTensor<BDataType>& b_k_n, HostTensor<CDataType>& c_m_n, const HostTensor<ScaleDataType>& scale_a, const HostTensor<ScaleDataType>& scale_b, const AElementOp& = {}, const BElementOp& = {}, const ACCElementOp& = {}) (reference_gemm.hpp:577)
CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k, const HostTensor<AQDataType>& aq_1_1, const HostTensor<BDataType>& b_k_n, const HostTensor<BQDataType>& bq_1_1, HostTensor<CDataType>& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) (reference_gemm.hpp:335)
HostTensor (host_tensor.hpp:336)
std::size_t HostTensor::get_length(std::size_t dim) const (host_tensor.hpp:388)
identity (functional.hpp:114)
numeric (numeric.hpp:81)