/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp Source File
block_softmax_2d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/reduce.hpp"
8 
9 #define _BLOCK_SOFTMAX_USE_UNPACK2 0
10 
11 namespace ck_tile {
12 
13 /*
14 simple 2D softmax implementation, normalizing along each row (dim=1)
15 requirements:
16  1) each row must fit within a single warp
17  2) the data type must be one dword (32 bits) wide
18 */
19 template <typename Problem_, typename Policy_ = void>
21 {
24 
25  using DataType = typename Problem::DataType;
26 
27  template <typename DistributedTensor, index_t dim = 1>
28  CK_TILE_DEVICE void
29  operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
30  {
31  const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
32  const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
33 #if _BLOCK_SOFTMAX_USE_UNPACK2
34  const auto f_max3 = [](auto e0, auto e1, auto e2) {
35  float rtn;
36  asm volatile("v_max3_f32 %0, %1, %2, %3" : "=v"(rtn) : "v"(e0), "v"(e1), "v"(e2));
37  return rtn;
38  };
39  const auto f_sum3 = [](auto e0, auto e1, auto e2) { return e0 + e1 + e2; };
40 #endif
41 
42  // compute row max
43  auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
44 #if _BLOCK_SOFTMAX_USE_UNPACK2
45  auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
46 #else
47  auto row_max = reduce_row_max(f_max);
48 #endif
49  sweep_tile<DistributedTensor>([&](auto idx) {
50  constexpr auto row_id = make_tuple(idx[number<0>{}]);
51  y(idx) = exp(x[idx] - row_max[row_id]);
52  });
53 
54  // compute row sum
55  auto reduce_row_sum = BlockReduce2D<decltype(y)>{y, DataType{0}};
56 #if _BLOCK_SOFTMAX_USE_UNPACK2
57  auto row_sum = reduce_row_sum(f_sum3, f_sum, sequence<1, 2>{});
58 #else
59  auto row_sum = reduce_row_sum(f_sum);
60 #endif
61  // reciprocal
62  auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
63  sweep_tile(row_sum, [&](auto idx) { r(idx) = DataType{1} / row_sum(idx); });
64 
65  // scale
66  sweep_tile<DistributedTensor>([&](auto idx) {
67  constexpr auto row_id = make_tuple(idx[number<0>{}]);
68  y(idx) = y(idx) * r(row_id);
69  });
70  }
71 
72  template <typename DistributedTensor, index_t dim = 1>
73  CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number<dim> = {})
74  {
75  auto y = DistributedTensor{}; // distributed tensor
76  operator()(x, y, number<dim>{});
77  return y;
78  }
79 };
80 
81 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:423
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr CK_TILE_HOST_DEVICE void sweep_tile(const F &f, UnpacksPerXDim={})
Definition: sweep_tile.hpp:231
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T &, const typename T::DataType &) -> BlockReduce2D< T >
Definition: block_softmax_2d.hpp:21
remove_cvref_t< Problem_ > Problem
Definition: block_softmax_2d.hpp:22
CK_TILE_DEVICE void operator()(const DistributedTensor &x, DistributedTensor &y, number< dim >={})
Definition: block_softmax_2d.hpp:29
remove_cvref_t< Policy_ > Policy
Definition: block_softmax_2d.hpp:23
typename Problem::DataType DataType
Definition: block_softmax_2d.hpp:25
Definition: integral_constant.hpp:13
static constexpr CK_TILE_HOST_DEVICE T infinity()
Definition: numeric.hpp:38