14 template <
typename OutDataType,
typename AccDataType>
18 for(
int n = 0; n < N; ++n)
20 o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
24 template <
typename OutDataType,
typename AccDataType>
33 template <
typename XDataType,
34 typename GammaDataType,
35 typename BetaDataType,
36 typename ComputeDataType,
38 typename MeanDataType,
39 typename InvStdDataType,
40 typename Epilogue = reference_layernorm2d_default_epilogue>
47 ComputeDataType epsilon,
48 Epilogue epilogue_functor = {})
50 auto layernorm2d_fwd_func = [&](
auto m) {
54 ComputeDataType mean = 0;
55 ComputeDataType variance = 0;
56 ComputeDataType divisor = 0;
58 for(
int n = 0; n < N; ++n)
61 ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
62 ComputeDataType delta = x - mean;
63 mean += delta / count;
64 ComputeDataType delta2 = x - mean;
65 variance += delta * delta2;
69 variance = variance / count;
70 divisor = ck_tile::type_convert<ComputeDataType>(1) /
ck_tile::sqrt(variance + epsilon);
72 if constexpr(!std::is_same_v<MeanDataType, ck_tile::null_type>)
73 mean_m(m) = ck_tile::type_convert<MeanDataType>(mean);
75 if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
76 invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);
79 for(
int n = 0; n < N; ++n)
81 ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
82 ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
83 ComputeDataType beta = ck_tile::type_convert<ComputeDataType>(beta_n(n));
84 auto a_ = (x - mean) * divisor;
85 a_ = a_ * gamma + beta;
90 epilogue_functor(m, y_m_n, acc);
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:417
void reference_layernorm2d_fwd(const HostTensor< XDataType > &x_m_n, const HostTensor< GammaDataType > &gamma_n, const HostTensor< BetaDataType > &beta_n, HostTensor< YDataType > &y_m_n, HostTensor< MeanDataType > &mean_m, HostTensor< InvStdDataType > &invStd_m, ComputeDataType epsilon, Epilogue epilogue_functor={})
Definition: reference_layernorm2d_fwd.hpp:41
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
decltype(auto) get_strides() const
Definition: host_tensor.hpp:394
Descriptor mDesc
Definition: host_tensor.hpp:800
Definition: reference_layernorm2d_fwd.hpp:13
auto operator()(int m, const HostTensor< AccDataType > &acc)
Definition: reference_layernorm2d_fwd.hpp:25
void operator()(int m, HostTensor< OutDataType > &o, const HostTensor< AccDataType > &acc)
Definition: reference_layernorm2d_fwd.hpp:15