13 #include <type_traits> 
   51 template <
typename ComputeDataType, 
typename OutDataType, 
typename AccDataType = ComputeDataType>
 
   56         is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
 
   57         "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
 
   59     double compute_error = 0;
 
   69     static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
 
   70                   "Warning: Unhandled OutDataType for setting up the relative threshold!");
 
   72     double output_error = 0;
 
   81     double midway_error = 
std::max(compute_error, output_error);
 
   83     static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
 
   84                   "Warning: Unhandled AccDataType for setting up the relative threshold!");
 
   95     return std::max(acc_error, midway_error);
 
  111 template <
typename ComputeDataType, 
typename OutDataType, 
typename AccDataType = ComputeDataType>
 
  113                                            const int number_of_accumulations = 1)
 
  117         is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
 
  118         "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
 
  120     auto expo            = std::log2(std::abs(max_possible_num));
 
  121     double compute_error = 0;
 
  131     static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
 
  132                   "Warning: Unhandled OutDataType for setting up the absolute threshold!");
 
  134     double output_error = 0;
 
  143     double midway_error = 
std::max(compute_error, output_error);
 
  145     static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
 
  146                   "Warning: Unhandled AccDataType for setting up the absolute threshold!");
 
  148     double acc_error = 0;
 
  158     return std::max(acc_error, midway_error);
 
  171 template <
typename T>
 
  172 std::ostream& 
operator<<(std::ostream& os, 
const std::vector<T>& v)
 
  174     using size_type = 
typename std::vector<T>::size_type;
 
  177     for(size_type idx = 0; idx < v.size(); ++idx)
 
  200 template <
typename Range, 
typename RefRange>
 
  203                                       const std::string& msg = 
"Error: Incorrect results!")
 
  205     if(out.size() != ref.size())
 
  207         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  225     const float error_percent =
 
  226         static_cast<float>(err_count) / 
static_cast<float>(total_size) * 100.f;
 
  227     std::cerr << 
"max err: " << max_err;
 
  228     std::cerr << 
", number of errors: " << err_count;
 
  229     std::cerr << 
", " << error_percent << 
"% wrong values" << std::endl;
 
  248 template <
typename Range, 
typename RefRange>
 
  250     std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  251         std::is_floating_point_v<ranges::range_value_t<Range>> &&
 
  252         !std::is_same_v<ranges::range_value_t<Range>, 
half_t>,
 
  256           const std::string& msg  = 
"Error: Incorrect results!",
 
  259           bool allow_infinity_ref = 
false)
 
  265     const auto is_infinity_error = [=](
auto o, 
auto r) {
 
  266         const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
 
  267         const bool both_infinite_and_same =
 
  268             std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
 
  270         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
 
  277     for(std::size_t i = 0; i < ref.size(); ++i)
 
  279         const double o = *std::next(std::begin(out), i);
 
  280         const double r = *std::next(std::begin(ref), i);
 
  281         err            = std::abs(o - r);
 
  282         if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
 
  284             max_err = err > max_err ? err : max_err;
 
  288                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  289                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  317 template <
typename Range, 
typename RefRange>
 
  319     std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  320         std::is_same_v<ranges::range_value_t<Range>, 
bf16_t>,
 
  324           const std::string& msg  = 
"Error: Incorrect results!",
 
  327           bool allow_infinity_ref = 
false)
 
  332     const auto is_infinity_error = [=](
auto o, 
auto r) {
 
  333         const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
 
  334         const bool both_infinite_and_same =
 
  335             std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
 
  337         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
 
  345     for(std::size_t i = 0; i < ref.size(); ++i)
 
  347         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  348         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  349         err            = std::abs(o - r);
 
  350         if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
 
  352             max_err = err > max_err ? err : max_err;
 
  356                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  357                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  386 template <
typename Range, 
typename RefRange>
 
  388     std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  389         std::is_same_v<ranges::range_value_t<Range>, 
half_t>,
 
  393           const std::string& msg  = 
"Error: Incorrect results!",
 
  396           bool allow_infinity_ref = 
false)
 
  401     const auto is_infinity_error = [=](
auto o, 
auto r) {
 
  402         const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
 
  403         const bool both_infinite_and_same =
 
  404             std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
 
  406         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
 
  412     double max_err = 
static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>
::min());
 
  413     for(std::size_t i = 0; i < ref.size(); ++i)
 
  415         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  416         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  417         err            = std::abs(o - r);
 
  418         if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
 
  420             max_err = err > max_err ? err : max_err;
 
  424                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  425                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  452 template <
typename Range, 
typename RefRange>
 
  453 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  454                   std::is_integral_v<ranges::range_value_t<Range>> &&
 
  455                   !std::is_same_v<ranges::range_value_t<Range>, 
bf16_t>)
 
  456 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 
  463                            const std::string& msg = 
"Error: Incorrect results!",
 
  474     for(std::size_t i = 0; i < ref.size(); ++i)
 
  476         const int64_t o = *std::next(std::begin(out), i);
 
  477         const int64_t r = *std::next(std::begin(ref), i);
 
  478         err             = std::abs(o - r);
 
  482             max_err = err > max_err ? err : max_err;
 
  486                 std::cerr << msg << 
" out[" << i << 
"] != ref[" << i << 
"]: " << o << 
" != " << r
 
  516 template <
typename Range, 
typename RefRange>
 
  517 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  518                   std::is_same_v<ranges::range_value_t<Range>, 
fp8_t>),
 
  522                            const std::string& msg               = 
"Error: Incorrect results!",
 
  523                            unsigned max_rounding_point_distance = 1,
 
  525                            bool allow_infinity_ref              = 
false)
 
  530     const auto is_infinity_error = [=](
auto o, 
auto r) {
 
  531         const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
 
  532         const bool both_infinite_and_same =
 
  533             std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
 
  535         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
 
  538     static const auto get_rounding_point_distance = [](
fp8_t o, 
fp8_t r) -> 
unsigned {
 
  539         static const auto get_sign_bit = [](
fp8_t v) -> 
bool {
 
  540             return 0x80 & bit_cast<uint8_t>(v);
 
  543         if(get_sign_bit(o) ^ get_sign_bit(r))
 
  549             return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
 
  557     for(std::size_t i = 0; i < ref.size(); ++i)
 
  559         const fp8_t o_fp8   = *std::next(std::begin(out), i);
 
  560         const fp8_t r_fp8   = *std::next(std::begin(ref), i);
 
  561         const double o_fp64 = type_convert<float>(o_fp8);
 
  562         const double r_fp64 = type_convert<float>(r_fp8);
 
  563         err                 = std::abs(o_fp64 - r_fp64);
 
  565              get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
 
  566            is_infinity_error(o_fp64, r_fp64))
 
  568             max_err = err > max_err ? err : max_err;
 
  572                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  573                           << 
"] != ref[" << i << 
"]: " << o_fp64 << 
" != " << r_fp64 << std::endl;
 
  601 template <
typename Range, 
typename RefRange>
 
  602 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  603                   std::is_same_v<ranges::range_value_t<Range>, 
bf8_t>),
 
  607                            const std::string& msg  = 
"Error: Incorrect results!",
 
  610                            bool allow_infinity_ref = 
false)
 
  615     const auto is_infinity_error = [=](
auto o, 
auto r) {
 
  616         const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
 
  617         const bool both_infinite_and_same =
 
  618             std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
 
  620         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
 
  627     for(std::size_t i = 0; i < ref.size(); ++i)
 
  629         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  630         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  631         err            = std::abs(o - r);
 
  632         if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
 
  634             max_err = err > max_err ? err : max_err;
 
  638                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  639                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  664 template <
typename Range, 
typename RefRange>
 
  665 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  666                   std::is_same_v<ranges::range_value_t<Range>, 
pk_fp4_t>),
 
  670                            const std::string& msg = 
"Error: Incorrect results!",
 
  682             std::cerr << msg << 
" out[" << index << 
"] != ref[" << index
 
  683                       << 
"]: " << type_convert<float>(
pk_fp4_t{o})
 
  684                       << 
" != " << type_convert<float>(
pk_fp4_t{r}) << std::endl;
 
  689     for(std::size_t i = 0; i < ref.size(); ++i)
 
  691         const pk_fp4_t o = *std::next(std::begin(out), i);
 
  692         const pk_fp4_t r = *std::next(std::begin(ref), i);
 
  700     return err_count == 0;
 
#define CK_TILE_HOST
Definition: config.hpp:40
 
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
 
Definition: cluster_descriptor.hpp:13
 
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:33
 
typename pk_fp4_t::type pk_fp4_raw_t
Definition: pk_fp4.hpp:152
 
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:31
 
_BitInt(8) fp8_t
Definition: float8.hpp:204
 
CK_TILE_HOST bool check_size_mismatch(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!")
Check for size mismatch between output and reference ranges.
Definition: check_err.hpp:201
 
int8_t int8_t
Definition: int8.hpp:20
 
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
 
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations=1)
Calculate relative error threshold for numerical comparisons.
Definition: check_err.hpp:52
 
pk_float4_e2m1_t pk_fp4_t
Definition: pk_fp4.hpp:151
 
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
Report error statistics for numerical comparisons.
Definition: check_err.hpp:223
 
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Calculate absolute error threshold for numerical comparisons.
Definition: check_err.hpp:112
 
std::ostream & operator<<(std::ostream &os, const std::vector< T > &v)
Stream operator overload for vector output.
Definition: check_err.hpp:172
 
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type CK_TILE_HOST check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6, bool allow_infinity_ref=false)
Check errors between floating point ranges using the specified tolerances.
Definition: check_err.hpp:254
 
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:29
 
int32_t int32_t
Definition: integer.hpp:10
 
ck_tile::bf8_t BF8
8-bit brain floating point type
Definition: check_err.hpp:27
 
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:37
 
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
 
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:25
 
constexpr int ERROR_DETAIL_LIMIT
Maximum number of error values to display when checking errors.
Definition: check_err.hpp:22
 
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
 
_Float16 half_t
Definition: half.hpp:111
 
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:35
 
_BitInt(4) int4_t
Definition: data_type.hpp:32
 
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
 
constexpr bool is_same_v
Definition: type.hpp:283
 
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
 
signed __int64 int64_t
Definition: stdint.h:135
 
Definition: integral_constant.hpp:13
 
Definition: type_traits.hpp:115
 
Definition: numeric.hpp:81
 
Definition: numeric.hpp:18
 
Definition: pk_fp4.hpp:76
 
constexpr CK_TILE_HOST_DEVICE type _unpack(number< I >) const