/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp Source File
reference_rmsnorm2d_fwd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // Note: for simplicity, each functor only care about single M
13 {
14  template <typename OutDataType, typename AccDataType>
16  {
17  const int N = acc.mDesc.get_lengths()[1];
18  for(int n = 0; n < N; ++n)
19  {
20  o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
21  }
22  }
23 
24  template <typename OutDataType, typename AccDataType>
25  auto operator()(int m, const HostTensor<AccDataType>& acc)
26  {
28  operator()(m, o, acc);
29  return o;
30  }
31 };
32 
33 template <typename XDataType,
34  typename GammaDataType,
35  typename ComputeDataType,
36  typename YDataType,
37  typename InvRmsDataType,
38  typename UnquantYDataType,
39  typename Epilogue = reference_rmsnorm2d_default_epilogue>
41  const HostTensor<GammaDataType>& gamma_n,
42  HostTensor<YDataType>& y_m_n,
44  HostTensor<UnquantYDataType>& unquant_y_m_n,
45  ComputeDataType epsilon,
46  Epilogue epilogue_functor = {})
47 {
48  auto rmsnorm2d_fwd_func = [&](auto m) {
49  const int N = x_m_n.mDesc.get_lengths()[1];
50 
51  ComputeDataType mean_square = 0;
52  ComputeDataType divisor = 0;
53 
54  for(int n = 0; n < N; ++n)
55  {
56  ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
57  mean_square += x * x;
58  }
59 
60  mean_square = mean_square / N;
61  divisor = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(mean_square + epsilon);
62 
63  if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
64  invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);
65 
66  HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
67  for(int n = 0; n < N; ++n)
68  {
69  ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
70  ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
71  acc(m, n) = x * divisor * gamma;
72  }
73 
74  if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
75  {
76  epilogue_functor(m, unquant_y_m_n, y_m_n, acc);
77  }
78  else
79  {
80  epilogue_functor(m, y_m_n, acc);
81  }
82  };
83 
84  make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
85  std::thread::hardware_concurrency());
86 }
87 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:417
void reference_rmsnorm2d_fwd(const HostTensor< XDataType > &x_m_n, const HostTensor< GammaDataType > &gamma_n, HostTensor< YDataType > &y_m_n, HostTensor< InvRmsDataType > &invRms_m, HostTensor< UnquantYDataType > &unquant_y_m_n, ComputeDataType epsilon, Epilogue epilogue_functor={})
Definition: reference_rmsnorm2d_fwd.hpp:40
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
decltype(auto) get_strides() const
Definition: host_tensor.hpp:394
Descriptor mDesc
Definition: host_tensor.hpp:800
struct reference_rmsnorm2d_default_epilogue
Definition: reference_rmsnorm2d_fwd.hpp:13
auto operator()(int m, const HostTensor< AccDataType > &acc)
Definition: reference_rmsnorm2d_fwd.hpp:25
void operator()(int m, HostTensor< OutDataType > &o, const HostTensor< AccDataType > &acc)
Definition: reference_rmsnorm2d_fwd.hpp:15