// CPU reference GEMM: C(m, n) = acc_op(sum_k a_op(A(m, k)) * b_op(B(k, n))).
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename AElementOp = ck_tile::identity, typename BElementOp = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_k_n,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op = {},
                                 const BElementOp& b_element_op = {},
                                 const ACCElementOp& acc_element_op = {})
{
    const std::size_t K = a_m_k.get_length(1);
    // One invocation per output element: dot product of row m of A with column n of B.
    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_element_op(a_m_k(m, k));
            BDataType v_b = b_element_op(b_k_n(k, n));
            v_acc += ck_tile::type_convert<AccDataType>(v_a) *
                     ck_tile::type_convert<AccDataType>(v_b);
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };
    make_ParallelTensorFunctor(f_mn, c_m_n.get_length(0), c_m_n.get_length(1))(
        std::thread::hardware_concurrency());
}
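A minimal host-side usage sketch, assuming half_t inputs with float accumulation and the default identity element ops; the extents, fill pattern, and construction of HostTensor directly from a list of lengths are illustrative assumptions, not part of the header:

    #include "ck_tile/host.hpp"

    void example_reference_gemm()
    {
        using ck_tile::half_t;
        const std::size_t M = 128, N = 128, K = 64;

        ck_tile::HostTensor<half_t> a_m_k({M, K});
        ck_tile::HostTensor<half_t> b_k_n({K, N});
        ck_tile::HostTensor<half_t> c_m_n({M, N});

        // Deterministic test data through the element accessor used above.
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t k = 0; k < K; ++k)
                a_m_k(m, k) = ck_tile::type_convert<half_t>(0.01f * float((m + k) % 7));
        for(std::size_t k = 0; k < K; ++k)
            for(std::size_t n = 0; n < N; ++n)
                b_k_n(k, n) = ck_tile::type_convert<half_t>(0.02f * float((k + n) % 5));

        // fp32 accumulation; the element ops default to identity.
        ck_tile::reference_gemm<half_t, half_t, float, half_t>(a_m_k, b_k_n, c_m_n);
    }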
// Naive GPU reference kernel: one thread per element of C; the A/B/C addressing is chosen at compile time from the layout parameters.
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename LayoutA, typename LayoutB, typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A, BDataType* B, CDataType* C,
                                  ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                                  ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC)
{
    // Map the flat thread index to a (row, col) coordinate of C.
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int row = idx / N;
    int col = idx % N;
    if(row < M && col < N)
    {
        AccDataType acc = 0.0;
        for(int k = 0; k < K; ++k)
        {
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k : k * strideB + col;
            acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
                   ck_tile::type_convert<AccDataType>(B[b_index]);
        }
        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}
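To make the layout arithmetic concrete, a small worked example with illustrative numbers: for a row-major 4x3 A with strideA = 3 (the row pitch), element (row = 2, k = 1) lives at 2 * 3 + 1 = 7; for a column-major 4x3 A with strideA = 4 (the column pitch), the same element lives at 1 * 4 + 2 = 6. The B and C index selection follows the same pattern with their respective strides.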
// Host launcher for the single-GEMM reference kernel.
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename LayoutA, typename LayoutB, typename LayoutC>
void reference_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr,
                        index_t M, index_t N, index_t K,
                        index_t stride_a, index_t stride_b, index_t stride_c)
{
    // One thread per output element, rounded up to whole 256-thread blocks.
    int totalElements = M * N;
    int numThreadsPerBlock = 256;
    int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(
            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
}
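A device-side usage sketch using the HIP runtime; the layouts, extents, and strides below are illustrative assumptions (row-major A and C, column-major B), and error checking is omitted:

    #include <hip/hip_runtime.h>
    #include "ck_tile/host.hpp"

    void example_reference_gemm_gpu()
    {
        using ck_tile::half_t;
        using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
        using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor;

        const ck_tile::index_t M = 128, N = 128, K = 64;

        half_t *a_dev, *b_dev, *c_dev;
        hipMalloc(&a_dev, M * K * sizeof(half_t));
        hipMalloc(&b_dev, K * N * sizeof(half_t));
        hipMalloc(&c_dev, M * N * sizeof(half_t));
        // ... fill a_dev and b_dev, e.g. via hipMemcpy from host buffers ...

        // Row pitch K for row-major A, column pitch K for column-major B, row pitch N for C.
        ck_tile::reference_gemm_gpu<half_t, half_t, float, half_t, RowMajor, ColMajor, RowMajor>(
            a_dev, b_dev, c_dev, M, N, K, /*stride_a=*/K, /*stride_b=*/K, /*stride_c=*/N);
        hipDeviceSynchronize();
    }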
// Host launcher for the batched reference GEMM: one kernel launch per batch,
// with each batch's A/B/C base pointer offset by its batch stride.
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename LayoutA, typename LayoutB, typename LayoutC>
void reference_batched_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr,
                                index_t M, index_t N, index_t K,
                                index_t stride_a, index_t stride_b, index_t stride_c,
                                index_t batch_stride_A, index_t batch_stride_B,
                                index_t batch_stride_C, index_t batch_count)
{
    int totalElements = M * N;
    int numThreadsPerBlock = 256;
    int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }
}
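Continuing the previous sketch (same type aliases and device buffers, now sized for batch_count matrices), and assuming packed batches where each matrix is stored immediately after the previous one, the batch strides are just the per-matrix element counts:

    const ck_tile::index_t batch_count = 8;
    ck_tile::reference_batched_gemm_gpu<half_t, half_t, float, half_t, RowMajor, ColMajor, RowMajor>(
        a_dev, b_dev, c_dev, M, N, K,
        /*stride_a=*/K, /*stride_b=*/K, /*stride_c=*/N,
        /*batch_stride_A=*/M * K, /*batch_stride_B=*/K * N, /*batch_stride_C=*/M * N,
        batch_count);

Note that the launcher issues one launch per batch on the default stream, so this reference path runs the batches serially.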