/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp Source File
thread_welford.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 template <typename T, bool kFastFDiv = false>
11 CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count, bool_constant<kFastFDiv> = {})
12 {
13  // TODO: check nan? maybe no
14  T delta = x - mean;
15  if(kFastFDiv && std::is_same_v<T, float>)
16  {
17  mean += delta * __builtin_amdgcn_rcpf(count);
18  }
19  else
20  {
21  mean += delta / count;
22  }
23  T delta2 = x - mean;
24  var += delta * delta2;
25 }
26 
27 template <typename T, bool kFastFDiv = false>
28 CK_TILE_DEVICE static void welford_merge(T& mean_a,
29  T& var_a,
30  int& count_a,
31  T mean_b,
32  T var_b,
33  int count_b,
34  bool_constant<kFastFDiv> = {})
35 {
36  int count = count_a + count_b;
37  T count_ = type_convert<T>(count);
38  T count_a_ = type_convert<T>(count_a);
39  T count_b_ = type_convert<T>(count_b);
40  T count_b_over_count;
41  if(kFastFDiv && std::is_same_v<T, float>)
42  {
43  count_b_over_count =
44  count == 0 ? type_convert<T>(0) : count_b_ * __builtin_amdgcn_rcpf(count_);
45  }
46  else
47  {
48  count_b_over_count = count == 0 ? type_convert<T>(0) : count_b_ / count_;
49  }
50 
51  T delta = mean_b - mean_a;
52  mean_a += delta * count_b_over_count;
53  var_a += var_b + delta * delta * count_a_ * count_b_over_count;
54  count_a = count;
55 }
56 
57 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE void welford_update(T &mean, T &var, T x, int count, bool_constant< kFastFDiv >={})
Definition: thread_welford.hpp:11
Definition: integral_constant.hpp:13