15     template <
typename OutDataType, 
typename AccDataType>
 
   19         for(
int n = 0; n < N; ++n)
 
   21             o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
 
   25     template <
typename OutDataType, 
typename AccDataType>
 
   34 template <
typename XDataType,
 
   35           typename GammaDataType,
 
   36           typename ComputeDataType,
 
   38           typename InvRmsDataType,
 
   39           typename UnquantYDataType,
 
   40           typename Epilogue = reference_rmsnorm2d_default_epilogue>
 
   46                              ComputeDataType epsilon,
 
   47                              Epilogue epilogue_functor = {},
 
   48                              const int use_model_sensitive_rmsnorm =
 
   51     auto rmsnorm2d_fwd_func = [&](
auto m) {
 
   54         ComputeDataType mean_square = 0;
 
   55         ComputeDataType divisor     = 0;
 
   57         for(
int n = 0; n < N; ++n)
 
   59             ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
 
   63         mean_square = mean_square / N;
 
   64         divisor = ck_tile::type_convert<ComputeDataType>(1) / 
ck_tile::sqrt(mean_square + epsilon);
 
   66         if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
 
   67             invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);
 
   70         for(
int n = 0; n < N; ++n)
 
   72             ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
 
   73             ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
 
   74             if(use_model_sensitive_rmsnorm ==
 
   78                 acc(m, n) = x * divisor * gamma;
 
   80             else if(use_model_sensitive_rmsnorm ==
 
   83                 if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
 
   85                     const auto tmp0 = float_to_bf16<bf16_rounding_mode::standard>(x * divisor);
 
   86                     const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
 
   87                         type_convert<ComputeDataType>(tmp0) * gamma);
 
   88                     const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
 
   93                     const auto tmp   = type_convert<XDataType>(x * divisor);
 
   94                     const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma;
 
  100         if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
 
  102             epilogue_functor(m, unquant_y_m_n, y_m_n, acc);
 
  106             epilogue_functor(m, y_m_n, acc);
 
  111         std::thread::hardware_concurrency());
 
Definition: cluster_descriptor.hpp:13
 
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
 
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:413
 
void reference_rmsnorm2d_fwd(const HostTensor< XDataType > &x_m_n, const HostTensor< GammaDataType > &gamma_n, HostTensor< YDataType > &y_m_n, HostTensor< InvRmsDataType > &invRms_m, HostTensor< UnquantYDataType > &unquant_y_m_n, ComputeDataType epsilon, Epilogue epilogue_functor={}, const int use_model_sensitive_rmsnorm=static_cast< int >(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))
Definition: reference_rmsnorm2d_fwd.hpp:41
 
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
 
Definition: host_tensor.hpp:336
 
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
 
decltype(auto) get_strides() const
Definition: host_tensor.hpp:394
 
Descriptor mDesc
Definition: host_tensor.hpp:800
 
Definition: reference_rmsnorm2d_fwd.hpp:14
 
auto operator()(int m, const HostTensor< AccDataType > &acc)
Definition: reference_rmsnorm2d_fwd.hpp:26
 
void operator()(int m, HostTensor< OutDataType > &o, const HostTensor< AccDataType > &acc)
Definition: reference_rmsnorm2d_fwd.hpp:16