/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/library/utility/host_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/library/utility/host_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/library/utility/host_gemm.hpp Source File
host_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "host_tensor.hpp"
7 
8 template <typename AType,
9  typename BType,
10  typename CType,
11  typename AElementwiseOperation,
12  typename BElementwiseOperation,
13  typename CElementwiseOperation>
15  const Tensor<BType>& b_k_n,
16  Tensor<CType>& c_m_n,
17  const AElementwiseOperation& a_element_op,
18  const BElementwiseOperation& b_element_op,
19  const CElementwiseOperation& c_element_op)
20 {
21  auto f_mk_kn_mn = [&](auto m, auto n) {
22  const int K = a_m_k.mDesc.GetLengths()[1];
23 
24  float v_acc = 0;
25 
26  for(int k = 0; k < K; ++k)
27  {
28  float v_a;
29  float v_b;
30 
31  a_element_op(v_a, static_cast<const float>(a_m_k(m, k)));
32  b_element_op(v_b, static_cast<const float>(b_k_n(k, n)));
33 
34  v_acc += v_a * v_b;
35  }
36 
37  float v_c;
38 
39  c_element_op(v_c, v_acc);
40 
41  c_m_n(m, n) = v_c;
42  };
43 
44  make_ParallelTensorFunctor(f_mk_kn_mn,
45  c_m_n.mDesc.GetLengths()[0],
46  c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
47 }
void host_gemm_mk_kn_mn(const Tensor< AType > &a_m_k, const Tensor< BType > &b_k_n, Tensor< CType > &c_m_n, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op)
Definition: host_gemm.hpp:14
auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:270
const std::vector< std::size_t > & GetLengths() const
Tensor wrapper that performs static and dynamic buffer logic. The tensor is based on a descriptor sto...
Definition: host_tensor.hpp:277
Descriptor mDesc
Definition: host_tensor.hpp:712