// Compile-time switch for the `#if _BLOCK_SOFTMAX_USE_UNPACK2` sections below, which
// define three-input reducers and pass them to the row reductions together with
// sequence<1, 2>{}. It is defined as 0 here, so those sections are compiled out.
// NOTE(review): names starting with an underscore followed by an uppercase letter are
// reserved for the implementation in C++; consider renaming in a follow-up change.
9 #define _BLOCK_SOFTMAX_USE_UNPACK2 0
19 template <
typename Problem_,
typename Policy_ =
void>
27 template <
typename DistributedTensor, index_t dim = 1>
// Two-input reducer: returns max(e0, e1). Passed to reduce_row_max below.
31 const auto f_max = [](
auto e0,
auto e1) {
return max(e0, e1); };
// Two-input reducer: returns e0 + e1. Passed to reduce_row_sum below.
32 const auto f_sum = [](
auto e0,
auto e1) {
return e0 + e1; };
33 #if _BLOCK_SOFTMAX_USE_UNPACK2
// Three-input max reducer, only defined when _BLOCK_SOFTMAX_USE_UNPACK2 is non-zero.
34 const auto f_max3 = [](
auto e0,
auto e1,
auto e2) {
// Inline assembly issuing v_max3_f32 with output operand %0 bound to `rtn` and input
// operands %1..%3 bound to e0, e1, e2 (all using the "v" constraint).
// NOTE(review): the mnemonic is the f32 form while the parameters are `auto` —
// presumably only float operands are expected here; confirm against callers.
// `rtn` is declared in a line that is not visible in this excerpt.
36 asm volatile(
"v_max3_f32 %0, %1, %2, %3" :
"=v"(rtn) :
"v"(e0),
"v"(e1),
"v"(e2));
// Three-input sum reducer: returns e0 + e1 + e2. Passed to reduce_row_sum in the
// _BLOCK_SOFTMAX_USE_UNPACK2 variant below.
39 const auto f_sum3 = [](
auto e0,
auto e1,
auto e2) {
return e0 + e1 + e2; };
// Compute row_max via reduce_row_max. The unpack-2 variant passes the three-input
// and two-input max reducers plus sequence<1, 2>{}; the other variant passes only
// the two-input reducer.
// NOTE(review): reduce_row_max is defined outside the lines visible here, and the
// preprocessor lines separating the two alternatives are not shown — presumably the
// result holds the per-row maximum of x; confirm.
44 #if _BLOCK_SOFTMAX_USE_UNPACK2
45 auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
47 auto row_max = reduce_row_max(f_max);
// For each index `idx` that sweep_tile hands to the lambda, store
// exp(x[idx] - row_max[row_id]) into y.
49 sweep_tile<DistributedTensor>([&](
auto idx) {
// row_id is a one-element tuple holding component 0 of idx; it is used to index row_max.
50 constexpr
auto row_id =
make_tuple(idx[number<0>{}]);
// Subtracting the row_max entry before exponentiating — presumably the usual
// max-shift that keeps softmax numerically stable; confirm.
51 y(idx) =
exp(x[idx] - row_max[row_id]);
// Compute row_sum via reduce_row_sum. The unpack-2 variant passes the three-input
// and two-input sum reducers plus sequence<1, 2>{}; the other variant passes only
// the two-input reducer.
// NOTE(review): reduce_row_sum is defined outside the lines visible here, and the
// preprocessor lines separating the two alternatives are not shown — presumably the
// result holds the per-row sum of the exponentials written to y; confirm.
56 #if _BLOCK_SOFTMAX_USE_UNPACK2
57 auto row_sum = reduce_row_sum(f_sum3, f_sum, sequence<1, 2>{});
59 auto row_sum = reduce_row_sum(f_sum);
// Create `r`, a static distributed tensor of DataType that reuses row_sum's tile
// distribution.
// NOTE(review): the code that fills `r` is not visible in this excerpt.
62 auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
// For each index `idx` that sweep_tile hands to the lambda, multiply y(idx) in place
// by the entry of r selected by row_id.
66 sweep_tile<DistributedTensor>([&](
auto idx) {
// row_id is a one-element tuple holding component 0 of idx; it is used to index r.
67 constexpr
auto row_id =
make_tuple(idx[number<0>{}]);
// NOTE(review): presumably r holds the reciprocal of row_sum so that this step
// normalises each row; the lines that populate r are not shown — confirm.
68 y(idx) = y(idx) * r(row_id);
72 template <
typename DistributedTensor, index_t dim = 1>
75 auto y = DistributedTensor{};
#define CK_TILE_DEVICE
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:414
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr CK_TILE_HOST_DEVICE void sweep_tile(const F &f, UnpacksPerXDim={})
Definition: sweep_tile.hpp:231
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T &, const typename T::DataType &) -> BlockReduce2D< T >
Definition: block_softmax_2d.hpp:21
remove_cvref_t< Problem_ > Problem
Definition: block_softmax_2d.hpp:22
CK_TILE_DEVICE void operator()(const DistributedTensor &x, DistributedTensor &y, number< dim >={})
Definition: block_softmax_2d.hpp:29
remove_cvref_t< Policy_ > Policy
Definition: block_softmax_2d.hpp:23
typename Problem::DataType DataType
Definition: block_softmax_2d.hpp:25
Definition: integral_constant.hpp:13
static constexpr CK_TILE_HOST_DEVICE T infinity()
Definition: numeric.hpp:38