21 typename ThreadClusterLengths_M_K,
22 typename ThreadClusterArrangeOrder,
// GetActualVariance: when true, Run() divides the merged variance accumulator by
// the merged count before returning it; when false, Run() returns the raw
// (un-normalized) accumulator.
23 bool GetActualVariance =
true>
// The M x K thread cluster must map onto exactly BlockSize threads (one slot per
// thread in the per-block scratch buffers used by Run()).
26 static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
27 "The product of cluster lengths should be same as BlockSize!");
// Merge: folds one partial Welford state (mean_b, var_b, count_b) into another
// (mean_a, var_a, count_a) in place using the pairwise-combination update:
//   mean = mean_a + delta * count_b / count
//   var  = var_a + var_b + delta^2 * count_a * count_b / count
// Parameters:
//   mean_a  - in/out running mean of the a-side state
//   var_a   - in/out accumulator; because Run() divides it by the count when
//             GetActualVariance is set, it appears to hold a sum of squared
//             deviations rather than a normalized variance
//   count_a - reference to the a-side element count (read below as the pre-merge
//             count)
//   mean_b, var_b, count_b - the b-side partial state, passed by value
38 template <
typename CountDataType>
39 __device__
static inline void
40 Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
// Combined element count of both partial states.
42 CountDataType count = count_a + count_b;
// Weight of the b-side in the merged mean. This is guarded so that merging two
// empty states yields 0 instead of dividing by zero.
43 T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
44 T delta = mean_b - mean_a;
45 mean_a += delta * count_b_over_count;
// Uses the pre-merge count_a, i.e. adds delta^2 * count_a * count_b / count.
46 var_a += var_b + delta * delta * count_a * count_b_over_count;
// NOTE(review): this excerpt does not show count_a being updated to `count`.
// Confirm the full source stores it; otherwise callers would see a stale count.
// Run: block-wide merge of per-thread Welford states. Each thread contributes its
// (mean_value, var_value, count). After the merge, every thread reads back the
// merged state stored at `offset`. That offset is computed in lines not shown in
// this excerpt; presumably it is the first K slot of the thread's M row (TODO
// confirm).
50 template <
typename CountDataType>
51 __device__
static void Run(T& mean_value, T& var_value, CountDataType& count)
// Block-shared scratch with one slot per thread for each state component.
53 __shared__ T mean_block_buf[BlockSize];
54 __shared__ T var_block_buf[BlockSize];
55 __shared__ CountDataType count_block_buf[BlockSize];
// Presumably log2(BufferLength_K): it sets the number of halving steps in the
// tree merge below (TODO confirm against get_shift).
57 constexpr
auto cluster_len_shift = get_shift<BufferLength_K>();
// This thread's (m, k) coordinates within the M x K thread cluster.
59 const auto thread_cluster_idx =
62 const auto thread_m_cluster_id = thread_cluster_idx[
Number<0>{}];
63 const auto thread_k_cluster_id = thread_cluster_idx[
Number<1>{}];
// Publish this thread's partial state into its slot `offset1` (computed in lines
// not shown in this excerpt).
67 mean_block_buf[offset1] = mean_value;
68 var_block_buf[offset1] = var_value;
69 count_block_buf[offset1] = count;
// Tree reduction along K. At step I only threads with
// k < 1 << (cluster_len_shift - 1 - I) are active. Each active thread merges the
// partner slot offset2 into its own slot offset1; offset2 is presumably
// indOffset further along K (TODO confirm).
// NOTE(review): the LDS barriers between read and write phases are not visible
// in this excerpt (block_sync_lds is referenced by this header). Confirm they
// separate every step.
74 constexpr
index_t indOffset = 1 << (cluster_len_shift - 1 - I());
76 if(thread_k_cluster_id < indOffset)
81 T mean1 = mean_block_buf[offset1];
82 T var1 = var_block_buf[offset1];
83 CountDataType count1 = count_block_buf[offset1];
85 T mean2 = mean_block_buf[offset2];
86 T var2 = var_block_buf[offset2];
87 CountDataType count2 = count_block_buf[offset2];
89 Merge(mean1, var1, count1, mean2, var2, count2);
91 mean_block_buf[offset1] = mean1;
92 var_block_buf[offset1] = var1;
93 count_block_buf[offset1] = count1;
// Read the fully merged state back from the result slot `offset`.
101 count = count_block_buf[offset];
102 mean_value = mean_block_buf[offset];
// With GetActualVariance the accumulator is normalized by the merged count;
// otherwise the raw accumulator is returned.
// NOTE(review): if the merged count is 0 this divides by zero (Merge() guards
// its own division, this line does not). Confirm callers never reduce an
// all-empty row, or add a guard.
104 if constexpr(GetActualVariance)
105 var_value = var_block_buf[offset] / count;
107 var_value = var_block_buf[offset];
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:52
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
__host__ constexpr __device__ auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition: cluster_descriptor.hpp:13
Definition: blockwise_welford.hpp:25
static constexpr auto thread_cluster_desc
Definition: blockwise_welford.hpp:35
static constexpr auto BufferLength_K
Definition: blockwise_welford.hpp:30
static constexpr auto block_buf_desc_m_k
Definition: blockwise_welford.hpp:32
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition: blockwise_welford.hpp:51
static constexpr auto BufferLength_M
Definition: blockwise_welford.hpp:29
static __device__ void Merge(T &mean_a, T &var_a, CountDataType &count_a, T mean_b, T var_b, CountDataType count_b)
Definition: blockwise_welford.hpp:40
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33