Composable Kernel — source listing
File: include/ck_tile/host/reference/reference_batched_elementwise.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include <thread>
9 
10 namespace ck_tile {
11 
12 template <typename ADataType,
13  typename BDataType,
14  typename AccDataType,
15  typename CDataType,
16  typename AElementOp = ck_tile::identity,
17  typename BElementOp = ck_tile::identity,
18  typename BinaryElementOp = ck_tile::plus<AccDataType>>
20  const HostTensor<BDataType>& b_b_m_n,
21  HostTensor<CDataType>& c_b_m_n,
22  const AElementOp& a_element_op = {},
23  const BElementOp& b_element_op = {},
24  const BinaryElementOp& binary_element_op = {})
25 {
26  const ck_tile::index_t N = c_b_m_n.mDesc.get_lengths()[2];
27 
28  const bool broadcast_a_dim_b = (a_b_m_n.get_lengths()[0] == 1);
29  const bool broadcast_a_dim_m = (a_b_m_n.get_lengths()[1] == 1);
30  const bool broadcast_a_dim_n = (a_b_m_n.get_lengths()[2] == 1);
31 
32  const bool broadcast_b_dim_b = (b_b_m_n.get_lengths()[0] == 1);
33  const bool broadcast_b_dim_m = (b_b_m_n.get_lengths()[1] == 1);
34  const bool broadcast_b_dim_n = (b_b_m_n.get_lengths()[2] == 1);
35 
36  auto f = [&](auto batch, auto m) {
37  for(ck_tile::index_t n = 0; n < N; ++n)
38  {
39  AccDataType v_a{};
40  {
41  ck_tile::index_t i_b = (broadcast_a_dim_b ? 0 : batch);
42  ck_tile::index_t i_m = (broadcast_a_dim_m ? 0 : m);
43  ck_tile::index_t i_n = (broadcast_a_dim_n ? 0 : n);
44 
45  v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_b_m_n(i_b, i_m, i_n)));
46  }
47 
48  AccDataType v_b{};
49  {
50  ck_tile::index_t i_b = (broadcast_b_dim_b ? 0 : batch);
51  ck_tile::index_t i_m = (broadcast_b_dim_m ? 0 : m);
52  ck_tile::index_t i_n = (broadcast_b_dim_n ? 0 : n);
53 
54  v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_b_m_n(i_b, i_m, i_n)));
55  }
56 
57  c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(binary_element_op(v_a, v_b));
58  }
59  };
60 
61  make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
62  std::thread::hardware_concurrency());
63 }
64 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_batched_elementwise(const HostTensor< ADataType > &a_b_m_n, const HostTensor< BDataType > &b_b_m_n, HostTensor< CDataType > &c_b_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const BinaryElementOp &binary_element_op={})
Definition: reference_batched_elementwise.hpp:19
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Descriptor mDesc
Definition: host_tensor.hpp:800
Definition: functional.hpp:86
Definition: math.hpp:50