reference_softmax.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST void
reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
{
    // number of dimensions of the input; the output must have the same rank
    index_t rank = static_cast<index_t>(x.get_num_of_dimension());
    assert(static_cast<std::size_t>(rank) == y.get_num_of_dimension());
    assert(dim == -1 || dim < rank);

    index_t target_dim = dim == -1 ? (rank - 1) : dim;
    index_t softmax_len = x.get_length(target_dim);
    index_t n_parallel = x.get_element_size() / softmax_len;
    auto x_len = x.get_lengths();

    // each invocation of f computes softmax over one slice along target_dim
    auto f = [&](auto i_element) {
        // decompose the flat slice index into a coordinate along all
        // dimensions except target_dim, which is left at zero
        std::vector<size_t> coord = [&]() {
            std::vector<size_t> t_(rank, 0);
            size_t r = i_element;
            for(index_t i = rank - 1; i >= 0; i--)
            {
                if(i == target_dim)
                    continue;
                t_[i] = r % x_len[i];
                r = r / x_len[i];
            }
            return t_;
        }();

        ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();

        // compute max
        for(auto idx = 0; idx < softmax_len; idx++)
        {
            auto c_ = coord;
            c_[target_dim] = idx;
            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
            v_max = v_max < v_x ? v_x : v_max;
        }

        ComputeType v_exp_sum = static_cast<ComputeType>(0);

        // sum
        for(auto idx = 0; idx < softmax_len; idx++)
        {
            auto c_ = coord;
            c_[target_dim] = idx;

            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));

            v_exp_sum += ck_tile::exp(v_x - v_max);
        }

        // elementwise
        for(auto idx = 0; idx < softmax_len; idx++)
        {
            auto c_ = coord;
            c_[target_dim] = idx;

            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));

            auto out = ck_tile::exp(v_x - v_max) / v_exp_sum;

            y(c_) = ck_tile::type_convert<OutputType>(out);
        }
    };

    // run f over all slices, distributed across the available hardware threads
    make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}

template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST auto reference_softmax(const HostTensor<InputType>& x, index_t dim = -1)
{
    // allocate the output with the same lengths and strides as the input
    HostTensor<OutputType> y(x.get_lengths(), x.get_strides());

    reference_softmax<InputType, ComputeType, OutputType>(x, y, dim);

    return y;
}
} // namespace ck_tile
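
A minimal usage sketch, not part of the header above: it assumes that ck_tile::HostTensor can be constructed from a list of lengths and indexed with a std::vector<std::size_t> coordinate, which is the same access form the reference implementation uses for x and y. It fills a small 2 x 4 tensor with logits and calls the allocating overload, which normalizes over the last dimension when dim is left at its default of -1.

#include <cstddef>
#include <vector>

#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"

int main()
{
    // 2 rows of 4 logits each (assumed HostTensor length-list constructor)
    ck_tile::HostTensor<float> x({2, 4});

    for(std::size_t i = 0; i < 2; ++i)
        for(std::size_t j = 0; j < 4; ++j)
            x(std::vector<std::size_t>{i, j}) = static_cast<float>(j);

    // InputType = float, ComputeType = float, OutputType defaults to ComputeType;
    // the result tensor is allocated and returned by the second overload
    auto y = ck_tile::reference_softmax<float, float>(x);

    // each row of y now sums to approximately 1, since the reference subtracts
    // the row maximum before exponentiating (numerically stable softmax)
    return 0;
}

The first overload is the in-place variant: pass a pre-allocated HostTensor<OutputType> y with the same rank as x, together with an optional dim selecting the axis to normalize over (the last axis when dim == -1).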