Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_welford.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp"

namespace ck {

// clang-format off
// Assume:
//  1) work_buffer is a buffer (typically LDS) allocated outside as workspace
//  2) work_buffer has T elements, and its size is no less than 3*BlockSize
//  3) mean_value, var_value and count are the input data in vgpr from each thread
//  4) mean_value, var_value and count are the over-written reduced output in vgpr for each thread
//  5) Merge mean and M from ThreadwiseWelford
// clang-format on
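// BlockwiseWelford merges the per-thread partial Welford results (mean, M2 and
// count) across the whole thread block: each thread stages its partial state in
// LDS and the partials are combined with a binary-tree reduction along the K
// dimension of the thread cluster.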
template <typename T,
          index_t BlockSize,
          typename ThreadClusterLengths_M_K,
          typename ThreadClusterArrangeOrder,
          bool GetActualVariance = true>
struct BlockwiseWelford
{
    static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
                  "The product of cluster lengths should be same as BlockSize!");

    static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
    static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);

    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    template <typename CountDataType>
    __device__ static inline void
    Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
    {
        CountDataType count  = count_a + count_b;
        T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
        T delta              = mean_b - mean_a;
        mean_a += delta * count_b_over_count;
        var_a += var_b + delta * delta * count_a * count_b_over_count;
        count_a = count;
    }

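    // A sketch of the update Merge implements (the standard pairwise combination
    // of Welford partials). Note that var_a/var_b carry the running sum of squared
    // deviations M2, not the variance; Run only divides by the combined count at
    // the end when GetActualVariance is true:
    //
    //   n     = n_a + n_b
    //   delta = mean_b - mean_a
    //   mean  = mean_a + delta * (n_b / n)
    //   M2    = M2_a + M2_b + delta^2 * n_a * (n_b / n)
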
    template <typename CountDataType>
    __device__ static void Run(T& mean_value, T& var_value, CountDataType& count)
    {
        __shared__ T mean_block_buf[BlockSize];
        __shared__ T var_block_buf[BlockSize];
        __shared__ CountDataType count_block_buf[BlockSize];

        constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
        const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

        index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);

        mean_block_buf[offset1]  = mean_value;
        var_block_buf[offset1]   = var_value;
        count_block_buf[offset1] = count;

        block_sync_lds();

        // Binary-tree reduction over the K dimension: each pass halves the number
        // of active K lanes and merges each surviving partial with the partner
        // partial indOffset lanes away.
        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());

            if(thread_k_cluster_id < indOffset)
            {
                index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                     make_tuple(0, indOffset));

                T mean1              = mean_block_buf[offset1];
                T var1               = var_block_buf[offset1];
                CountDataType count1 = count_block_buf[offset1];

                T mean2              = mean_block_buf[offset2];
                T var2               = var_block_buf[offset2];
                CountDataType count2 = count_block_buf[offset2];

                Merge(mean1, var1, count1, mean2, var2, count2);

                mean_block_buf[offset1]  = mean1;
                var_block_buf[offset1]   = var1;
                count_block_buf[offset1] = count1;
            }

            block_sync_lds();
        });

        index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

        count      = count_block_buf[offset];
        mean_value = mean_block_buf[offset];

        if constexpr(GetActualVariance)
            var_value = var_block_buf[offset] / count;
        else
            var_value = var_block_buf[offset];
    };
};
} // namespace ck
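A minimal usage sketch (not part of the header above; the alias BlockwiseWelford256, the function reduce_welford_partials, the float/index_t types and the 8x32 cluster split are illustrative assumptions). A kernel typically accumulates a per-thread Welford state first, e.g. with ThreadwiseWelford, and then calls BlockwiseWelford::Run to combine those partials across the whole block:

#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/utility/sequence.hpp"

// Hypothetical instantiation: a 256-thread block arranged as 8 rows (M) x 32 columns (K).
using BlockwiseWelford256 = ck::BlockwiseWelford<float,
                                                 256,                 // BlockSize
                                                 ck::Sequence<8, 32>, // ThreadClusterLengths_M_K
                                                 ck::Sequence<1, 0>>; // ThreadClusterArrangeOrder

__device__ void reduce_welford_partials(float& mean, float& var, ck::index_t& count)
{
    // On entry: this thread's partial mean, M2 and element count (in VGPRs).
    // On exit: the blockwise mean and count and, because GetActualVariance
    // defaults to true, the actual variance (combined M2 / combined count).
    BlockwiseWelford256::Run(mean, var, count);
}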