 14 template <typename ADataType,
 28           const AElementOp& a_element_op = {},
 29           const BElementOp& b_element_op = {},
 30           const ACCElementOp& acc_element_op = {})
 36     auto f_mn = [&](auto m, auto n) {
 37         AccDataType v_acc = 0, v_block_acc = 0;
 39         static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
 40                       std::is_same_v<ADataType, bf8_t>);
 41         static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
 42                       std::is_same_v<BDataType, pk_int4_t>);
 43         static_assert(std::is_same_v<AccDataType, float>);
 44         static_assert(std::is_same_v<CDataType, float> ||
 45                       std::is_same_v<CDataType, ck_tile::half_t>);
 46         for(std::size_t k = 0; k < K; ++k)
 50             if constexpr(std::is_same_v<ADataType, pk_int4_t>)
 52                 const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
 61                 v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
 63             if constexpr(std::is_same_v<BDataType, pk_int4_t>)
 65                 const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
 72             else if constexpr(std::is_same_v<BDataType, fp8_t>)
 78                 v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
 80             v_block_acc += v_a * v_b;
 83             if((k + 1) % QuantGroupSize == 0)
 86                 index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
 87                 index_t inner_dim = (aquant) ? k / QuantGroupSize : n;
 89                 if constexpr(std::is_same_v<QDataType, float>)
 91                     scale = q(outer_dim, inner_dim);
 93                 else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
 97                 else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
103                     static_assert(false, "Unexpected Q datatype.");
105                 v_block_acc *= scale;
106                 v_acc += v_block_acc;
111         c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
115     std::cout << std::endl;
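reference_gemm_quant above accumulates QuantGroupSize partial products in v_block_acc and, whenever (k + 1) % QuantGroupSize == 0, applies one per-group scale before folding the block into v_acc; the aquant flag only selects whether the scale tensor is indexed as (m, k / QuantGroupSize) for A-side quantization or (k / QuantGroupSize, n) for B-side quantization. Below is a minimal standalone sketch of the same arithmetic for the B-quantized case, written against plain row-major float buffers rather than HostTensor; the reset of the block accumulator after each group is elided in the listing but implied, and is made explicit here.

#include <cstddef>
#include <vector>

// Standalone sketch: C = A * dequant(B), with one float scale per
// (group_size x 1) block of B. Row-major A (MxK), B (KxN), C (MxN);
// scale is (K / group_size) x N. Illustration only, not the library API,
// and B is assumed to be already decoded to float.
void gemm_b_group_quant_ref(const std::vector<float>& a,
                            const std::vector<float>& b,
                            const std::vector<float>& scale,
                            std::vector<float>& c,
                            std::size_t M, std::size_t N, std::size_t K,
                            std::size_t group_size)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f, block_acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
            {
                block_acc += a[m * K + k] * b[k * N + n];
                if((k + 1) % group_size == 0)
                {
                    // One scale covers the K-slice [k + 1 - group_size, k] of column n.
                    block_acc *= scale[(k / group_size) * N + n];
                    acc += block_acc;
                    block_acc = 0.f; // start the next group
                }
            }
            c[m * N + n] = acc;
        }
}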
118 template <typename ADataType,
122           typename AccDataType,
132           const AElementOp& a_element_op = {},
133           const BElementOp& b_element_op = {},
134           const ACCElementOp& acc_element_op = {})
136     static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
137     static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
138     static_assert(std::is_same_v<AccDataType, float>);
139     static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
140     static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
145     auto f_mn = [&](auto m, auto n) {
147         AccDataType v_acc = 0;
149         float a_scale = aq_m_1(m, 0);
150         float b_scale = bq_1_n(0, n);
153         for(std::size_t k = 0; k < K; ++k)
159             if constexpr(std::is_same_v<ADataType, pk_int4_t>)
161                 const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
170                 v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
174             if constexpr(std::is_same_v<BDataType, pk_int4_t>)
176                 const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
185                 v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
191         v_acc = v_acc * a_scale * b_scale;
193         c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
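In the row/column-quantized variant the scales aq_m_1(m, 0) and bq_1_n(0, n) do not depend on k, so the reference applies them once to the finished dot product rather than inside the K loop. A standalone sketch of the same computation on plain float buffers (illustration only, not the library API):

#include <cstddef>
#include <vector>

// Standalone sketch of row/column quantization: one scale per row of A
// (aq[m]) and one per column of B (bq[n]); because both are constant over
// k, scaling the finished dot product equals scaling every product, which
// is exactly what v_acc * a_scale * b_scale does above.
void gemm_rowcol_quant_ref(const std::vector<float>& a,  // M x K, row-major
                           const std::vector<float>& aq, // M scales
                           const std::vector<float>& b,  // K x N, row-major
                           const std::vector<float>& bq, // N scales
                           std::vector<float>& c,        // M x N, row-major
                           std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                acc += a[m * K + k] * b[k * N + n];
            c[m * N + n] = acc * aq[m] * bq[n];
        }
}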
199 template <typename ADataType,
203           typename AccDataType,
213           const AElementOp& a_element_op = {},
214           const BElementOp& b_element_op = {},
215           const ACCElementOp& acc_element_op = {})
217     static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
218     static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
219     static_assert(std::is_same_v<AccDataType, float>);
220     static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
221     static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
226     auto f_mn = [&](auto m, auto n) {
228         AccDataType v_acc = 0;
230         const AccDataType a_scale = ck_tile::type_convert<AccDataType>(aq_1_1(0, 0));
231         const AccDataType b_scale = ck_tile::type_convert<AccDataType>(bq_1_1(0, 0));
234         for(std::size_t k = 0; k < K; ++k)
236             AccDataType v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
237             AccDataType v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
242         v_acc = v_acc * a_scale * b_scale;
244         c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
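The per-tensor variant is the degenerate case of the row/column one: a single a_scale and a single b_scale cover the whole problem and v_acc is scaled once after the K loop. A small self-contained check of the identity this relies on, with scales chosen as powers of two so the float comparison below is exact:

#include <cassert>
#include <cstddef>
#include <vector>

// Tiny standalone check of the identity the per-tensor path relies on:
// sum_k (a_scale * a_k) * (b_scale * b_k) == a_scale * b_scale * sum_k a_k * b_k,
// i.e. the two scalar scales can be pulled out of the K loop and applied once.
int main()
{
    const float a_scale = 0.25f, b_scale = 8.0f;
    const std::vector<float> a = {1.f, 2.f, 3.f, 4.f};
    const std::vector<float> b = {4.f, 3.f, 2.f, 1.f};

    float dequant_first = 0.f, scale_last = 0.f;
    for(std::size_t k = 0; k < a.size(); ++k)
    {
        dequant_first += (a_scale * a[k]) * (b_scale * b[k]);
        scale_last += a[k] * b[k];
    }
    scale_last *= a_scale * b_scale;

    assert(dequant_first == scale_last); // exact here: scales are powers of two
    return 0;
}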
250 template <typename ADataType,
252           typename AccDataType,
260           const AElementOp& a_element_op = {},
261           const BElementOp& b_element_op = {},
262           const ACCElementOp& acc_element_op = {})
268     auto f_mn = [&](auto m, auto n) {
269         AccDataType v_acc = 0;
271         for(std::size_t k = 0; k < K; ++k)
275             if constexpr(std::is_same_v<ADataType, pk_int4_t>)
277                 const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
286                 v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
288             if constexpr(std::is_same_v<BDataType, pk_int4_t>)
290                 const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
299                 v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
304         c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
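Several of the reference paths branch on pk_int4_t: one byte of A or B carries two 4-bit values, so a single load produces a pair of floats (see pk_int4_t_to_fp32x2_t and pk_int4_t_to_fp32x2_t_signed_conversion in the symbol list at the end). The standalone sketch below shows what such a decode looks like; the nibble order (low nibble first) and plain sign extension are assumptions for illustration, and the library converters in pk_int4.hpp are the authoritative decoding.

#include <cstdint>
#include <utility>

// Hedged sketch of a packed-int4 decode: split one byte into two signed
// 4-bit values and widen each to float. Illustration only.
inline std::pair<float, float> decode_packed_int4(std::uint8_t byte)
{
    auto sign_extend4 = [](unsigned nib) -> int {
        return (nib & 0x8u) ? static_cast<int>(nib) - 16 : static_cast<int>(nib);
    };
    const float lo = static_cast<float>(sign_extend4(byte & 0xFu));
    const float hi = static_cast<float>(sign_extend4((byte >> 4) & 0xFu));
    return {lo, hi};
}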
310 template <typename AsDataType,
313           typename AccDataType,
317           typename CDElementOp,
318           typename ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>,
319           typename BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>,
320           typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
328           const AElementOp& a_element_op = {},
329           const BElementOp& b_element_op = {},
330           const CDElementOp& acc_element_op = {})
337         generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number<AsDataType::size()>{});
340         generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number<BsDataType::size()>{});
343         generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number<DsDataType::size()>{});
346     auto a_elementwise_fn = [&](auto i, auto j) {
347         ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple);
353     auto b_elementwise_fn = [&](auto i, auto j) {
354         ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple);
359     auto f_mk_kn_mn = [&](auto m, auto n) {
360         AccDataType v_acc = 0;
361         for(std::size_t k = 0; k < K; ++k)
363             ADataType v_a = a_m_k(m, k);
364             BDataType v_b = b_k_n(k, n);
366                 ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
374                            ck_tile::type_convert<float>(v_acc),
375                            ck_tile::type_convert<float>(t(m, n))...);
379         c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
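reference_gemm_multiple_abd first fuses the extra A and B operands element-wise: generate_tie builds a tuple of references to the tensors, and ck_tile::apply expands that tuple so the variadic element-op receives one value per operand at the same (i, j). A standalone sketch of the same expansion using std::tie and std::apply; element_add is a made-up element-op used only for illustration.

#include <cstddef>
#include <tuple>
#include <vector>

// Standalone sketch of the fused multiple-A step: the operands live in a
// tuple of references and std::apply expands that tuple so the variadic
// element-op sees one value per operand at the same flat index.
int main()
{
    const std::size_t M = 2, K = 3;
    std::vector<float> a0(M * K, 1.f), a1(M * K, 2.f); // two A operands
    std::vector<float> a(M * K);                       // fused A fed into the GEMM loop

    auto element_add = [](float& dst, float x, float y) { dst = x + y; };

    auto as_tuple = std::tie(a0, a1); // tuple<std::vector<float>&, std::vector<float>&>
    for(std::size_t i = 0; i < M; ++i)
        for(std::size_t j = 0; j < K; ++j)
            std::apply([&](auto&... t) { element_add(a[i * K + j], t[i * K + j]...); },
                       as_tuple);
    return 0;
}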
385 template <typename ADataType,
388           typename AccDataType,
390           typename ACCElementOp,
391           typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
397           const ACCElementOp& acc_element_op = {})
403     auto f_mk_kn_mn = [&](auto m, auto n) {
404         AccDataType v_acc = 0;
405         for(std::size_t k = 0; k < K; ++k)
407             ADataType v_a = a_m_k(m, k);
408             BDataType v_b = b_k_n(k, n);
410                 ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
414         if constexpr(DsDataType::size() == 0)
416             acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
418         else if constexpr(DsDataType::size() == 1)
421                            ck_tile::type_convert<float>(v_acc),
422                            ck_tile::type_convert<float>(ds_m_n[0](m, n)));
424         else if constexpr(DsDataType::size() == 2)
427                            ck_tile::type_convert<float>(v_acc),
428                            ck_tile::type_convert<float>(ds_m_n[0](m, n)),
429                            ck_tile::type_convert<float>(ds_m_n[1](m, n)));
431         c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
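reference_gemm_multiple_d dispatches on DsDataType::size() and hands the accumulator plus zero, one, or two extra per-element D operands to acc_element_op. Below is a standalone sketch of the size() == 1 path; the acc + d bias/residual op is only an example of what a caller-supplied ACCElementOp might do, not the library's definition.

#include <cstddef>
#include <vector>

// Standalone sketch of the single-D epilogue: the element-op receives the
// accumulator and one extra per-element tensor D and writes the final C
// value. The op shown (acc + d0) is an illustrative choice.
void apply_multiple_d_epilogue(const std::vector<float>& acc, // M x N accumulators
                               const std::vector<float>& d0,  // M x N extra operand
                               std::vector<float>& c,         // M x N output
                               std::size_t M, std::size_t N)
{
    auto acc_element_op = [](float& out, float v_acc, float v_d0) { out = v_acc + v_d0; };

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            acc_element_op(c[m * N + n], acc[m * N + n], d0[m * N + n]);
}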
437 template <typename ADataType,
439           typename AccDataType,
454     int idx = blockIdx.x * blockDim.x + threadIdx.x;
458     if(row < M && col < N)
460         AccDataType acc = 0.0;
461         for(int k = 0; k < K; ++k)
466             int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
469             int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
475             if constexpr(std::is_same_v<ADataType, pk_int4_t>)
485                 v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
487             if constexpr(std::is_same_v<BDataType, pk_int4_t>)
497                 v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
502         int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
503                           ? row * strideC + col
504                           : col * strideC + row;
505         C[c_index] = ck_tile::type_convert<CDataType>(acc);
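naive_gemm_kernel assigns one thread per element of C and turns (row, col, k) into flat offsets that depend on the layout template parameters. The branch bodies for a_index and b_index are elided in this listing; the host-side helper below mirrors the visible c_index pattern (row * stride + col for row-major, col * stride + row for column-major) and should be read as an assumption about those elided branches.

#include <cstddef>

// Host-side sketch of the flat indexing the kernel computes per element.
inline std::size_t flat_index(bool row_major,
                              std::size_t row, std::size_t col, std::size_t stride)
{
    return row_major ? row * stride + col   // elements of one row are contiguous
                     : col * stride + row;  // elements of one column are contiguous
}

// Example: element (m, k) of a row-major M x K matrix A with leading
// dimension strideA lives at A[flat_index(true, m, k, strideA)].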
509 template <typename ADataType,
511           typename AccDataType,
526     int totalElements = M * N;
527     int numThreadsPerBlock = 256;
528     int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
530     naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
531         <<<numBlocks, numThreadsPerBlock>>>(
532             a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
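reference_gemm_gpu launches the kernel with one thread per output element and rounds the block count up so the partially filled last block still covers the tail of C. A quick standalone check of that arithmetic for an example size:

#include <cassert>

// One thread per C element; ceiling division for the grid size. For
// M = N = 1000 and 256 threads per block this is 3907 blocks: 3906 full
// blocks plus one block for the remaining 64 elements.
int main()
{
    const int M = 1000, N = 1000;
    const int totalElements = M * N; // 1'000'000
    const int numThreadsPerBlock = 256;
    const int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    assert(numBlocks == 3907);
    assert(numBlocks * numThreadsPerBlock >= totalElements); // every element gets a thread
    return 0;
}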
537 template <typename ADataType,
539           typename AccDataType,
558     int totalElements = M * N;
559     int numThreadsPerBlock = 256;
560     int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
562     for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
564         ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
565         BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
566         CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
567         naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
568             <<<numBlocks, numThreadsPerBlock>>>(
569                 d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
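The batched launcher reuses one launch configuration and only offsets the A, B and C base pointers by batch_id times the corresponding batch stride. A standalone sketch of that pointer arithmetic on host buffers; the densely packed strides M * K, K * N and M * N are an assumption for the example, not a requirement of the function.

#include <cstddef>
#include <vector>

// Per-batch pointer arithmetic: each batch's matrices start batch_stride
// elements after the previous batch's. Other stride values would leave
// padding between batches.
int main()
{
    const std::size_t M = 4, N = 8, K = 16, batch_count = 3;
    const std::size_t batch_stride_A = M * K;
    const std::size_t batch_stride_B = K * N;
    const std::size_t batch_stride_C = M * N;

    std::vector<float> A(batch_count * batch_stride_A);
    std::vector<float> B(batch_count * batch_stride_B);
    std::vector<float> C(batch_count * batch_stride_C);

    for(std::size_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        float* a = A.data() + batch_id * batch_stride_A; // same offsets the
        float* b = B.data() + batch_id * batch_stride_B; // launcher applies to
        float* c = C.data() + batch_id * batch_stride_C; // its device pointers
        (void)a; (void)b; (void)c; // a per-batch GEMM would run here
    }
    return 0;
}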
Referenced symbols (Doxygen cross-references for this listing):

Functions defined in reference_gemm.hpp:

  CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k, const HostTensor<QDataType>& q, const HostTensor<BDataType>& b_k_n, HostTensor<CDataType>& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {})  [line 24]
  CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k, const HostTensor<AQDataType>& aq_m_1, const HostTensor<BDataType>& b_k_n, const HostTensor<BQDataType>& bq_1_n, HostTensor<CDataType>& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {})  [line 127]
  CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k, const HostTensor<AQDataType>& aq_1_1, const HostTensor<BDataType>& b_k_n, const HostTensor<BQDataType>& bq_1_1, HostTensor<CDataType>& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {})  [line 208]
  CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k, const HostTensor<BDataType>& b_k_n, HostTensor<CDataType>& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {})  [line 257]
  CK_TILE_HOST void reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k, const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n, const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n, HostTensor<ADataType>& a_m_k, HostTensor<BDataType>& b_k_n, HostTensor<CDataType>& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const CDElementOp& acc_element_op = {})  [line 322]
  CK_TILE_HOST void reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k, const HostTensor<BDataType>& b_k_n, const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n, HostTensor<CDataType>& c_m_n, const ACCElementOp& acc_element_op = {})  [line 393]
  __global__ void naive_gemm_kernel(ADataType* A, BDataType* B, CDataType* C, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC)  [line 444]
  void reference_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c)  [line 516]
  void reference_batched_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t batch_stride_A, index_t batch_stride_B, index_t batch_stride_C, index_t batch_count)  [line 544]

Other referenced symbols:

  CK_TILE_HOST (macro)  [config.hpp:40]
  index_t (alias of int32_t)  [integer.hpp:9]
  fp32x2_t (packed pair of floats)  [pk_fp4.hpp:22]
  number<v> (alias of constant<v>)  [integral_constant.hpp:37]
  CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t)  [float8.hpp:751]
  CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t)  [float8.hpp:764]
  CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)  [pk_int4.hpp:105]
  CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t& x)  [pk_int4.hpp:120]
  CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)  [host_tensor.hpp:329]
  constexpr decltype(auto) apply(F&& f, Tuple&& t)  [tuple.hpp:526]
  constexpr CK_TILE_HOST_DEVICE auto generate_tie(F&& f, number<N>)  [tuple.hpp:435]
  std::size_t get_length(std::size_t dim) const  [host_tensor.hpp:388]