/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp Source File
reference_rowwise_quantization2d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include <thread>
9 
10 namespace ck_tile {
11 template <typename XDataType, typename ScaleDataType, typename QXDataType>
13  const HostTensor<ScaleDataType>& scale_m,
14  HostTensor<QXDataType>& qx_m_n)
15 {
16  auto f = [&](auto m) {
17  const int N = x_m_n.mDesc.get_lengths()[1];
18 
19  for(int n = 0; n < N; ++n)
20  {
21  auto v_x = x_m_n(m, n);
22  // scale = amax / 127 for int8
23  auto v_scale = type_convert<XDataType>(scale_m(m));
24  auto v_qx = v_x / v_scale;
25  qx_m_n(m, n) = type_convert<QXDataType>(saturates<QXDataType>{}(v_qx));
26  }
27  };
28 
30  scale_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
31 }
32 
33 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor< XDataType > &x_m_n, const HostTensor< ScaleDataType > &scale_m, HostTensor< QXDataType > &qx_m_n)
Definition: reference_rowwise_quantization2d.hpp:12
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
Descriptor mDesc
Definition: host_tensor.hpp:800
Definition: unary_element_function.hpp:56