/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_softmax.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_softmax.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_batched_softmax.hpp Source File
reference_batched_softmax.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include <thread>
9 
10 namespace ck_tile {
11 
12 template <typename ADataType,
13  typename CompDataType,
14  typename BDataType,
15  typename CompElementOp = ck_tile::identity>
17  const HostTensor<ADataType>& a_b_m_n,
18  HostTensor<BDataType>& b_b_m_n,
19  const CompElementOp& comp_element_op = {},
20  std::optional<std::reference_wrapper<HostTensor<CompDataType>>> lse_b_m = std::nullopt)
21 {
22  const int N = a_b_m_n.mDesc.get_lengths()[2];
23 
24  auto f = [&](auto batch, auto m) {
25  CompDataType v_max = -ck_tile::numeric<CompDataType>::infinity();
26 
27  // max
28  for(int n = 0; n < N; ++n)
29  {
30  const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
31 
32  v_max = v_max < v_a ? v_a : v_max;
33  }
34 
35  CompDataType v_exp_sum = 0;
36  // validate v_max if all the elements within a row are -INF
37  if(std::isinf(v_max) && v_max < 0)
38  {
39  v_max = ck_tile::type_convert<CompDataType>(0.f);
40  }
41 
42  // sum
43  for(int n = 0; n < N; ++n)
44  {
45  const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
46 
47  v_exp_sum += ck_tile::exp(v_a - v_max);
48  }
49 
50  // if sum is zero(masked), or nan/inf(other computation error), don't do divide
51  CompDataType inv_sum = (v_exp_sum == 0.f ? 1.f : 1.f / v_exp_sum);
52 
53  // elementwise
54  for(int n = 0; n < N; ++n)
55  {
56  const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
57  const CompDataType v_b = ck_tile::exp(v_a - v_max) * inv_sum;
58 
59  b_b_m_n(batch, m, n) = ck_tile::type_convert<BDataType>(comp_element_op(v_b));
60  }
61  // lse
62  if(lse_b_m)
63  {
64  lse_b_m->get()(batch, m) = v_max + ck_tile::log(v_exp_sum);
65  }
66  };
67 
68  make_ParallelTensorFunctor(f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])(
69  std::thread::hardware_concurrency());
70 }
71 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:432
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:423
CK_TILE_HOST void reference_batched_softmax(const HostTensor< ADataType > &a_b_m_n, HostTensor< BDataType > &b_b_m_n, const CompElementOp &comp_element_op={}, std::optional< std::reference_wrapper< HostTensor< CompDataType >>> lse_b_m=std::nullopt)
Definition: reference_batched_softmax.hpp:16
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
Descriptor mDesc
Definition: host_tensor.hpp:800
Definition: functional.hpp:86
static constexpr CK_TILE_HOST_DEVICE T infinity()
Definition: numeric.hpp:38