/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp Source File
reduction_functions_threadwise.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 
10 // Assume
11 // 1) SrcDesc is known at compile-time
12 // 2) DstDesc is known at compile-time
13 // 3) SrcBuffer is static buffer
14 // 4) DstBuffer is static buffer
15 template <typename AccDataType,
16  typename SrcThreadDesc_M_K,
17  typename DstThreadDesc_M,
18  typename OpReduce,
19  bool PropagateNan,
20  typename Accumulation =
21  detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
23 {
24  static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
25  static constexpr auto dst_thread_desc_m = DstThreadDesc_M{};
26 
27  static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
28  static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
29  static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
30 
31  static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
32 
33  using Op = OpReduce;
34 
35  template <typename SrcBufferType, typename DstBufferType>
36  __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
37  {
38  static_for<0, src_length_m, 1>{}([&](auto iM) {
39  constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM));
40 
41  static_for<0, src_length_k, 1>{}([&](auto iK) {
42  constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
43 
44  Accumulation::Calculate(dst_buf(Number<out_offset>{}), src_buf[Number<offset>{}]);
45  });
46  });
47  };
48 };
49 
50 // Assume
51 // 1) SrcDesc is known at compile-time
52 // 2) DstDesc is known at compile-time
53 // 3) SrcBuffer is static buffer
54 // 4) DstBuffer is static buffer
55 template <
56  typename AccDataType,
57  typename IndexDataType,
58  typename SrcThreadDesc_M_K,
59  typename DstThreadDesc_M,
60  typename OpReduce,
61  bool PropagateNan,
62  typename Accumulation =
63  detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
65 {
66  static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
67  static constexpr auto dst_thread_desc_m = DstThreadDesc_M{};
68 
69  static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
70  static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
71  static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
72 
73  static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
74 
75  template <typename SrcValueBufferType,
76  typename SrcIndexBufferType,
77  typename DstValueBufferType,
78  typename DstIndexBufferType>
79  __device__ static void Reduce(const SrcValueBufferType& src_val_buf,
80  const SrcIndexBufferType& src_idx_buf,
81  DstValueBufferType& dst_val_buf,
82  DstIndexBufferType& dst_idx_buf)
83  {
84  static_for<0, src_length_m, 1>{}([&](auto iM) {
85  constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM));
86 
87  static_for<0, src_length_k, 1>{}([&](auto iK) {
88  constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
89 
90  Accumulation::Calculate(dst_val_buf(Number<out_offset>{}),
91  src_val_buf[Number<offset>{}],
92  dst_idx_buf(Number<out_offset>{}),
93  src_idx_buf[Number<offset>{}]);
94  });
95  });
96  };
97 };
98 
99 } // namespace ck
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
ck::ThreadwiseReduction
Definition: reduction_functions_threadwise.hpp:23
static constexpr auto dst_thread_desc_m
Definition: reduction_functions_threadwise.hpp:25
static constexpr auto src_length_m
Definition: reduction_functions_threadwise.hpp:27
static __device__ void Reduce(const SrcBufferType &src_buf, DstBufferType &dst_buf)
Definition: reduction_functions_threadwise.hpp:36
static constexpr auto dst_length_m
Definition: reduction_functions_threadwise.hpp:29
OpReduce Op
Definition: reduction_functions_threadwise.hpp:33
static constexpr auto src_thread_desc_m_k
Definition: reduction_functions_threadwise.hpp:24
static constexpr auto src_length_k
Definition: reduction_functions_threadwise.hpp:28
ck::ThreadwiseReductionWithIndex
Definition: reduction_functions_threadwise.hpp:65
static constexpr auto src_thread_desc_m_k
Definition: reduction_functions_threadwise.hpp:66
static constexpr auto src_length_m
Definition: reduction_functions_threadwise.hpp:69
static __device__ void Reduce(const SrcValueBufferType &src_val_buf, const SrcIndexBufferType &src_idx_buf, DstValueBufferType &dst_val_buf, DstIndexBufferType &dst_idx_buf)
Definition: reduction_functions_threadwise.hpp:79
static constexpr auto src_length_k
Definition: reduction_functions_threadwise.hpp:70
static constexpr auto dst_length_m
Definition: reduction_functions_threadwise.hpp:71
static constexpr auto dst_thread_desc_m
Definition: reduction_functions_threadwise.hpp:67
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33