// Reference softmax over the innermost index `n` of a tensor addressed as (batch, m, n):
// reads a_b_m_n, computes in CompDataType, and stores the converted result into b_b_m_n.
// NOTE(review): this listing is a lossy extract. The number leading some lines is the original
// file's line number, and original lines missing from that sequence are not shown here (braces,
// the declarations of `N` and `v_max`, the statement that accumulates `v_exp_sum`, and others).
// Confirm every comment below against the real header before relying on it.
12 template <
typename ADataType,
13 typename CompDataType,
// Callable applied to each normalized value before it is converted and stored (see the store
// into b_b_m_n below); defaults to a value-initialized CompElementOp.
19 const CompElementOp& comp_element_op = {},
// Optional reference to a tensor that receives one value per (batch, m); defaults to "absent".
20 std::optional<std::reference_wrapper<HostTensor<CompDataType>>> lse_b_m = std::nullopt)
// Worker for a single (batch, m) row. It captures by reference, so it reads a_b_m_n and writes
// b_b_m_n / lse_b_m directly.
24 auto f = [&](
auto batch,
auto m) {
// Pass 1: running maximum of the row, after converting each element to CompDataType.
// NOTE(review): the initial value of v_max is declared on a line not shown in this extract.
28 for(
int n = 0; n < N; ++n)
30 const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
32 v_max = v_max < v_a ? v_a : v_max;
// Sum accumulator, starting at zero.
35 CompDataType v_exp_sum = 0;
// A row maximum of -inf is replaced by 0, so the later `v_a - v_max` is not (-inf) - (-inf),
// which would be NaN.
37 if(std::isinf(v_max) && v_max < 0)
39 v_max = ck_tile::type_convert<CompDataType>(0.f);
// Pass 2: second walk over the row.
// NOTE(review): presumably accumulates exp(v_a - v_max) into v_exp_sum; the accumulating
// statement itself is not visible in this extract.
43 for(
int n = 0; n < N; ++n)
45 const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
// Scale factor for normalization; a sum of exactly zero uses 1 instead of dividing by zero.
51 CompDataType inv_sum = (v_exp_sum == 0.f ? 1.f : 1.f / v_exp_sum);
// Pass 3: shifted exponential times the scale factor, passed through comp_element_op,
// converted to BDataType and written to the output tensor.
54 for(
int n = 0; n < N; ++n)
56 const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
57 const CompDataType v_b =
ck_tile::exp(v_a - v_max) * inv_sum;
59 b_b_m_n(batch, m, n) = ck_tile::type_convert<BDataType>(comp_element_op(v_b));
// Per-row log-sum-exp: the shift (v_max) plus the log of the sum.
// NOTE(review): presumably guarded by a has-value check on lse_b_m on a line not shown here.
64 lse_b_m->get()(batch, m) = v_max +
ck_tile::log(v_exp_sum);
// Tail of a call whose beginning is not shown in this extract.
// NOTE(review): presumably the thread count used to run `f` across all (batch, m) rows — confirm.
69 std::thread::hardware_concurrency());
#define CK_TILE_HOST
Definition: config.hpp:39
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:272
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:423
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:414
CK_TILE_HOST void reference_batched_softmax(const HostTensor< ADataType > &a_b_m_n, HostTensor< BDataType > &b_b_m_n, const CompElementOp &comp_element_op={}, std::optional< std::reference_wrapper< HostTensor< CompDataType >>> lse_b_m=std::nullopt)
Definition: reference_batched_softmax.hpp:16
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:162
Definition: host_tensor.hpp:279
Descriptor mDesc
Definition: host_tensor.hpp:678
Definition: functional.hpp:62
static constexpr CK_TILE_HOST_DEVICE T infinity()
Definition: numeric.hpp:38