13 #include <type_traits> 
   26 template <
typename ComputeDataType, 
typename OutDataType, 
typename AccDataType = ComputeDataType>
 
   38     static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
 
   39                       is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
 
   40                       is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, TF32> ||
 
   41                       is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
 
   42                       is_same_v<ComputeDataType, int>,
 
   43                   "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
 
   44     double compute_error = 0;
 
   45     if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
 
   46                  is_same_v<ComputeDataType, int>)
 
   55     static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
 
   56                       is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
 
   57                       is_same_v<OutDataType, F32> || is_same_v<ComputeDataType, TF32> ||
 
   58                       is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
 
   59                       is_same_v<OutDataType, int>,
 
   60                   "Warning: Unhandled OutDataType for setting up the relative threshold!");
 
   61     double output_error = 0;
 
   62     if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
 
   63                  is_same_v<OutDataType, int>)
 
   71     double midway_error = 
std::max(compute_error, output_error);
 
   73     static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
 
   74                       is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
 
   75                       is_same_v<AccDataType, F32> || is_same_v<ComputeDataType, TF32> ||
 
   76                       is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
 
   77                       is_same_v<AccDataType, int>,
 
   78                   "Warning: Unhandled AccDataType for setting up the relative threshold!");
 
   80     if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
 
   81                  is_same_v<AccDataType, int>)
 
   89     return std::max(acc_error, midway_error);
 
   92 template <
typename ComputeDataType, 
typename OutDataType, 
typename AccDataType = ComputeDataType>
 
  104     static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
 
  105                       is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
 
  106                       is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, TF32> ||
 
  107                       is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
 
  108                       is_same_v<ComputeDataType, int>,
 
  109                   "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
 
  110     auto expo            = std::log2(std::abs(max_possible_num));
 
  111     double compute_error = 0;
 
  112     if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
 
  113                  is_same_v<ComputeDataType, int>)
 
  122     static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
 
  123                       is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
 
  124                       is_same_v<OutDataType, F32> || is_same_v<ComputeDataType, TF32> ||
 
  125                       is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
 
  126                       is_same_v<OutDataType, int>,
 
  127                   "Warning: Unhandled OutDataType for setting up the absolute threshold!");
 
  128     double output_error = 0;
 
  129     if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
 
  130                  is_same_v<OutDataType, int>)
 
  138     double midway_error = 
std::max(compute_error, output_error);
 
  140     static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
 
  141                       is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
 
  142                       is_same_v<AccDataType, F32> || is_same_v<ComputeDataType, TF32> ||
 
  143                       is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
 
  144                       is_same_v<AccDataType, int>,
 
  145                   "Warning: Unhandled AccDataType for setting up the absolute threshold!");
 
  146     double acc_error = 0;
 
  147     if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
 
  148                  is_same_v<AccDataType, int>)
 
  157     return std::max(acc_error, midway_error);
 
  160 template <
typename Range,
 
  165         std::is_same_v<ranges::range_value_t<Range>, 
float> &&
 
  166         std::is_same_v<ComputeDataType, ck::tf32_t>,
 
  170           const std::string& msg = 
"Error: Incorrect results!",
 
  174     if(out.size() != ref.size())
 
  176         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  185     for(std::size_t i = 0; i < ref.size(); ++i)
 
  187         const double o = *std::next(std::begin(out), i);
 
  188         const double r = *std::next(std::begin(ref), i);
 
  189         err            = std::abs(o - r);
 
  190         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  192             max_err = err > max_err ? err : max_err;
 
  195                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  196                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  204         const float error_percent =
 
  205             static_cast<float>(err_count) / 
static_cast<float>(out.size()) * 100.f;
 
  206         std::cerr << 
"max err: " << max_err;
 
  207         std::cerr << 
", number of errors: " << err_count;
 
  208         std::cerr << 
", " << error_percent << 
"% wrong values" << std::endl;
 
  213 template <
typename Range,
 
  218         std::is_floating_point_v<ranges::range_value_t<Range>> &&
 
  219         !std::is_same_v<ranges::range_value_t<Range>, 
half_t> &&
 
  220         !std::is_same_v<ComputeDataType, ck::tf32_t>,
 
  224           const std::string& msg = 
"Error: Incorrect results!",
 
  228     if(out.size() != ref.size())
 
  230         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  239     for(std::size_t i = 0; i < ref.size(); ++i)
 
  241         const double o = *std::next(std::begin(out), i);
 
  242         const double r = *std::next(std::begin(ref), i);
 
  243         err            = std::abs(o - r);
 
  244         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  246             max_err = err > max_err ? err : max_err;
 
  249                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  250                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  258         const float error_percent =
 
  259             static_cast<float>(err_count) / 
static_cast<float>(out.size()) * 100.f;
 
  260         std::cerr << 
"max err: " << max_err;
 
  261         std::cerr << 
", number of errors: " << err_count;
 
  262         std::cerr << 
", " << error_percent << 
"% wrong values" << std::endl;
 
  267 template <
typename Range,
 
  272         std::is_same_v<ranges::range_value_t<Range>, 
bhalf_t>,
 
  276           const std::string& msg = 
"Error: Incorrect results!",
 
  280     if(out.size() != ref.size())
 
  282         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  292     for(std::size_t i = 0; i < ref.size(); ++i)
 
  294         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  295         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  296         err            = std::abs(o - r);
 
  297         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  299             max_err = err > max_err ? err : max_err;
 
  303                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  304                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  311         const float error_percent =
 
  312             static_cast<float>(err_count) / 
static_cast<float>(out.size()) * 100.f;
 
  313         std::cerr << 
"max err: " << max_err;
 
  314         std::cerr << 
", number of errors: " << err_count;
 
  315         std::cerr << 
", " << error_percent << 
"% wrong values" << std::endl;
 
  320 template <
typename Range,
 
  325         std::is_same_v<ranges::range_value_t<Range>, 
half_t>,
 
  329           const std::string& msg = 
"Error: Incorrect results!",
 
  333     if(out.size() != ref.size())
 
  335         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  344     for(std::size_t i = 0; i < ref.size(); ++i)
 
  346         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  347         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  348         err            = std::abs(o - r);
 
  349         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  351             max_err = err > max_err ? err : max_err;
 
  355                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  356                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  363         const float error_percent =
 
  364             static_cast<float>(err_count) / 
static_cast<float>(out.size()) * 100.f;
 
  365         std::cerr << 
"max err: " << max_err;
 
  366         std::cerr << 
", number of errors: " << err_count;
 
  367         std::cerr << 
", " << error_percent << 
"% wrong values" << std::endl;
 
  372 template <
typename Range,
 
  376                   std::is_integral_v<ranges::range_value_t<Range>> &&
 
  377                   !std::is_same_v<ranges::range_value_t<Range>, 
bhalf_t> &&
 
  378                   !std::is_same_v<ranges::range_value_t<Range>, 
f8_t> &&
 
  379                   !std::is_same_v<ranges::range_value_t<Range>, 
bf8_t>)
 
  380 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 
  387           const std::string& msg = 
"Error: Incorrect results!",
 
  391     if(out.size() != ref.size())
 
  393         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  402     for(std::size_t i = 0; i < ref.size(); ++i)
 
  404         const int64_t o = *std::next(std::begin(out), i);
 
  405         const int64_t r = *std::next(std::begin(ref), i);
 
  406         err             = std::abs(o - r);
 
  410             max_err = err > max_err ? err : max_err;
 
  414                 std::cerr << msg << 
" out[" << i << 
"] != ref[" << i << 
"]: " << o << 
" != " << r
 
  422         const float error_percent =
 
  423             static_cast<float>(err_count) / 
static_cast<float>(out.size()) * 100.f;
 
  424         std::cerr << 
"max err: " << max_err;
 
  425         std::cerr << 
", number of errors: " << err_count;
 
  426         std::cerr << 
", " << error_percent << 
"% wrong values" << std::endl;
 
  431 template <
typename Range,
 
  435                   std::is_same_v<ranges::range_value_t<Range>, 
f8_t>),
 
  439           const std::string& msg = 
"Error: Incorrect results!",
 
  443     if(out.size() != ref.size())
 
  445         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  455     for(std::size_t i = 0; i < ref.size(); ++i)
 
  457         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  458         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  459         err            = std::abs(o - r);
 
  461         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  463             max_err = err > max_err ? err : max_err;
 
  467                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  468                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  476         std::cerr << std::setw(12) << std::setprecision(7) << 
"max err: " << max_err
 
  477                   << 
" number of errors: " << err_count << std::endl;
 
  482 template <
typename Range,
 
  486                   std::is_same_v<ranges::range_value_t<Range>, 
bf8_t>),
 
  490           const std::string& msg = 
"Error: Incorrect results!",
 
  494     if(out.size() != ref.size())
 
  496         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  505     for(std::size_t i = 0; i < ref.size(); ++i)
 
  507         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  508         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  509         err            = std::abs(o - r);
 
  510         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  512             max_err = err > max_err ? err : max_err;
 
  516                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  517                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  524         std::cerr << std::setw(12) << std::setprecision(7) << 
"max err: " << max_err << std::endl;
 
  529 template <
typename Range,
 
  533                   std::is_same_v<ranges::range_value_t<Range>, 
f4_t>),
 
  537           const std::string& msg = 
"Error: Incorrect results!",
 
  541     if(out.size() != ref.size())
 
  543         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  553     for(std::size_t i = 0; i < ref.size(); ++i)
 
  555         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  556         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  557         err            = std::abs(o - r);
 
  559         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  561             max_err = err > max_err ? err : max_err;
 
  565                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  566                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  574         std::cerr << std::setw(12) << std::setprecision(7) << 
"max err: " << max_err
 
  575                   << 
" number of errors: " << err_count << std::endl;
 
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
 
iter_value_t< ranges::iterator_t< R > > range_value_t
Definition: ranges.hpp:28
 
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Definition: check_err.hpp:93
 
double get_relative_threshold(const int number_of_accumulations=1)
Definition: check_err.hpp:27
 
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_same_v< ranges::range_value_t< Range >, float > &&std::is_same_v< ComputeDataType, ck::tf32_t >, bool >::type check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-5)
Definition: check_err.hpp:168
 
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:33
 
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:31
 
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:29
 
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:37
 
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:25
 
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:35
 
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1762
 
unsigned _BitInt(4) f4_t
Definition: data_type.hpp:33
 
_Float16 half_t
Definition: data_type.hpp:31
 
_BitInt(19) tf32_t
Definition: data_type.hpp:29
 
ushort bhalf_t
Definition: data_type.hpp:30
 
_BitInt(4) int4_t
Definition: data_type.hpp:32
 
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
 
constexpr bool is_same_v
Definition: type.hpp:283
 
long int64_t
Definition: data_type.hpp:464
 
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
 
signed int int32_t
Definition: stdint.h:123
 
signed char int8_t
Definition: stdint.h:121
 
Definition: numeric_limits.hpp:309
 
Definition: numeric_utils.hpp:10
 
Definition: amd_ck_fp8.hpp:49
 
Definition: amd_ck_fp8.hpp:36