14 template <
typename OutDataType,
typename AccDataType>
18 for(
int n = 0; n < N; ++n)
20 o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
24 template <
typename OutDataType,
typename AccDataType>
33 template <
typename XDataType,
34 typename GammaDataType,
35 typename ComputeDataType,
37 typename InvRmsDataType,
38 typename Epilogue = reference_rmsnorm2d_default_epilogue>
43 ComputeDataType epsilon,
44 Epilogue epilogue_functor = {})
46 auto rmsnorm2d_fwd_func = [&](
auto m) {
49 ComputeDataType mean_square = 0;
50 ComputeDataType divisor = 0;
52 for(
int n = 0; n < N; ++n)
54 ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
58 mean_square = mean_square / N;
59 divisor = ck_tile::type_convert<ComputeDataType>(1) /
ck_tile::sqrt(mean_square + epsilon);
61 if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
62 invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);
65 for(
int n = 0; n < N; ++n)
67 ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
68 ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
69 acc(m, n) = x * divisor * gamma;
72 epilogue_functor(m, y_m_n, acc);
76 std::thread::hardware_concurrency());
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:272
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:408
void reference_rmsnorm2d_fwd(const HostTensor< XDataType > &x_m_n, const HostTensor< GammaDataType > &gamma_n, HostTensor< YDataType > &y_m_n, HostTensor< InvRmsDataType > &invRms_m, ComputeDataType epsilon, Epilogue epilogue_functor={})
Definition: reference_rmsnorm2d_fwd.hpp:39
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:162
Definition: host_tensor.hpp:279
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:333
decltype(auto) get_strides() const
Definition: host_tensor.hpp:337
Descriptor mDesc
Definition: host_tensor.hpp:678
Definition: reference_rmsnorm2d_fwd.hpp:13
auto operator()(int m, const HostTensor< AccDataType > &acc)
Definition: reference_rmsnorm2d_fwd.hpp:25
void operator()(int m, HostTensor< OutDataType > &o, const HostTensor< AccDataType > &acc)
Definition: reference_rmsnorm2d_fwd.hpp:15