/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_dropout.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_dropout.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_dropout.hpp Source File
reference_batched_dropout.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
9 
10 namespace ck_tile {
11 
12 template <typename DataType, typename RandValOutputDataType>
14  const HostTensor<RandValOutputDataType>& randval_b_m_n,
15  const uint8_t& p_undrop_in_uint8_t,
16  const float scale)
17 {
18  const int N = in_out_b_m_n.mDesc.get_lengths()[2];
19  auto f = [&](auto batch, auto m) {
20  for(int n = 0; n < N; ++n)
21  {
22  float tmp = ck_tile::type_convert<float>(in_out_b_m_n(batch, m, n)) * scale;
23  in_out_b_m_n(batch, m, n) = randval_b_m_n(batch, m, n) <= p_undrop_in_uint8_t
24  ? ck_tile::type_convert<DataType>(tmp)
25  : DataType(0);
26  }
27  };
28 
30  f, randval_b_m_n.mDesc.get_lengths()[0], randval_b_m_n.mDesc.get_lengths()[1])(
31  std::thread::hardware_concurrency());
32 }
33 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
ck_tile
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
CK_TILE_HOST void reference_batched_dropout(HostTensor< DataType > &in_out_b_m_n, const HostTensor< RandValOutputDataType > &randval_b_m_n, const uint8_t &p_undrop_in_uint8_t, const float scale)
Definition: reference_batched_dropout.hpp:13
unsigned char uint8_t
Definition: stdint.h:124
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
ck_tile::HostTensor
Definition: host_tensor.hpp:336
Descriptor mDesc
Definition: host_tensor.hpp:800