/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp Source File
blockwise_softmax.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
12 
13 namespace ck {
14 
25 template <index_t BlockSize,
26  typename AccDataType,
27  typename ThreadMap_M_K, // thread_id to m_k
28  typename ThreadClusterDesc_M_K,
29  typename ThreadSliceDesc_M_K,
30  bool IgnoreNaN = false>
32 {
33  static constexpr auto I0 = Number<0>{};
34  static constexpr auto I1 = Number<1>{};
35  static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
36  static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
37 
39  make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));
40 
41  using ThreadwiseMaxReduce = typename conditional<
42  IgnoreNaN,
43  ThreadwiseReduction<AccDataType,
44  ThreadSliceDesc_M_K,
47  false,
49  ThreadwiseReduction<AccDataType,
50  ThreadSliceDesc_M_K,
53  false>>::type;
54 
55  using ThreadwiseSumReduce = typename conditional<
56  IgnoreNaN,
57  ThreadwiseReduction<AccDataType,
58  ThreadSliceDesc_M_K,
61  false,
63  ThreadwiseReduction<AccDataType,
64  ThreadSliceDesc_M_K,
67  false>>::type;
68 
69  using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());
70 
72  BlockSize,
74  ThreadMap_M_K,
76  false>;
77 
79  BlockSize,
81  ThreadMap_M_K,
83  false>;
84 
86 
87  template <typename CThreadBuffer, typename WorkspaceBuffer>
88  __host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf)
89  {
90  // find max value
91  static_for<0, MRepeat, 1>{}([&](auto I) {
92  max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
93  });
94  ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
95  static_for<0, MRepeat, 1>{}([&](auto I) {
96  BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
98  });
99 
100  // calculate exp for elements, P=exp(s-max)
101  static_for<0, MRepeat, 1>{}([&](auto iM) {
102  static_for<0, KRepeat, 1>{}([&](auto iK) {
103  auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
104  in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
105  ? 0
106  : math::exp(in_thread_buf[offset] - max_value_buf(iM));
107  });
108  });
109 
110  // sum data
111  static_for<0, MRepeat, 1>{}([&](auto I) {
112  sum_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
113  });
114  ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf);
115  static_for<0, MRepeat, 1>{}([&](auto I) {
116  BlockwiseSumReduce::Reduce(reduce_work_buf, sum_value_buf(I));
117  block_sync_lds();
118  });
119  }
120 
123 };
124 
125 } // namespace ck
__host__ T exp(T x)
Definition: math_v2.hpp:391
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Blockwise softmax.
Definition: blockwise_softmax.hpp:32
decltype(make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0)))) ThreadSliceDesc_M
Definition: blockwise_softmax.hpp:39
decltype(ThreadClusterDesc_M_K{}.GetLengths()) ThreadClusterLengths_M_K
Definition: blockwise_softmax.hpp:69
__host__ __device__ void Run(CThreadBuffer &in_thread_buf, WorkspaceBuffer &reduce_work_buf)
Definition: blockwise_softmax.hpp:88
static constexpr index_t MRepeat
Definition: blockwise_softmax.hpp:35
typename conditional< IgnoreNaN, ThreadwiseReduction< AccDataType, ThreadSliceDesc_M_K, ThreadSliceDesc_M, reduce::Add, false, detail::AccumulateWithNanIgnore< reduce::Add, AccDataType > >, ThreadwiseReduction< AccDataType, ThreadSliceDesc_M_K, ThreadSliceDesc_M, reduce::Add, false > >::type ThreadwiseSumReduce
Definition: blockwise_softmax.hpp:67
static constexpr index_t KRepeat
Definition: blockwise_softmax.hpp:36
typename conditional< IgnoreNaN, ThreadwiseReduction< AccDataType, ThreadSliceDesc_M_K, ThreadSliceDesc_M, reduce::Max, false, detail::AccumulateWithNanIgnore< reduce::Max, AccDataType > >, ThreadwiseReduction< AccDataType, ThreadSliceDesc_M_K, ThreadSliceDesc_M, reduce::Max, false > >::type ThreadwiseMaxReduce
Definition: blockwise_softmax.hpp:53
static constexpr auto I0
Definition: blockwise_softmax.hpp:33
static constexpr auto I1
Definition: blockwise_softmax.hpp:34
BufferType sum_value_buf
Definition: blockwise_softmax.hpp:122
BufferType max_value_buf
Definition: blockwise_softmax.hpp:121
Definition: reduction_functions_blockwise.hpp:101
static __device__ void Reduce(BufferType &work_buffer, AccDataType &in_out_value)
Definition: reduction_functions_blockwise.hpp:116
Definition: reduction_functions_threadwise.hpp:23
Definition: functional.hpp:100
Definition: reduction_functions_accumulate.hpp:17
Definition: integral_constant.hpp:20
Definition: reduction_operator.hpp:37
Definition: reduction_operator.hpp:163
Definition: functional2.hpp:33