13 #include <type_traits>
51 template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
56 is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
57 "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
59 double compute_error = 0;
69 static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
70 "Warning: Unhandled OutDataType for setting up the relative threshold!");
72 double output_error = 0;
81 double midway_error =
std::max(compute_error, output_error);
83 static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
84 "Warning: Unhandled AccDataType for setting up the relative threshold!");
95 return std::max(acc_error, midway_error);
111 template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
113 const int number_of_accumulations = 1)
117 is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
118 "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
120 auto expo = std::log2(std::abs(max_possible_num));
121 double compute_error = 0;
131 static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
132 "Warning: Unhandled OutDataType for setting up the absolute threshold!");
134 double output_error = 0;
143 double midway_error =
std::max(compute_error, output_error);
145 static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
146 "Warning: Unhandled AccDataType for setting up the absolute threshold!");
148 double acc_error = 0;
158 return std::max(acc_error, midway_error);
171 template <
typename T>
172 std::ostream&
operator<<(std::ostream& os,
const std::vector<T>& v)
174 using size_type =
typename std::vector<T>::size_type;
177 for(size_type idx = 0; idx < v.size(); ++idx)
200 template <
typename Range,
typename RefRange>
203 const std::string& msg =
"Error: Incorrect results!")
205 if(out.size() != ref.size())
207 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
225 const float error_percent =
226 static_cast<float>(err_count) /
static_cast<float>(total_size) * 100.f;
227 std::cerr <<
"max err: " << max_err;
228 std::cerr <<
", number of errors: " << err_count;
229 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
248 template <
typename Range,
typename RefRange>
250 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
251 std::is_floating_point_v<ranges::range_value_t<Range>> &&
252 !std::is_same_v<ranges::range_value_t<Range>,
half_t>,
256 const std::string& msg =
"Error: Incorrect results!",
259 bool allow_infinity_ref =
false)
265 const auto is_infinity_error = [=](
auto o,
auto r) {
266 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
267 const bool both_infinite_and_same =
268 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
270 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
277 for(std::size_t i = 0; i < ref.size(); ++i)
279 const double o = *std::next(std::begin(out), i);
280 const double r = *std::next(std::begin(ref), i);
281 err = std::abs(o - r);
282 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
284 max_err = err > max_err ? err : max_err;
288 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
289 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
317 template <
typename Range,
typename RefRange>
319 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
320 std::is_same_v<ranges::range_value_t<Range>,
bf16_t>,
324 const std::string& msg =
"Error: Incorrect results!",
327 bool allow_infinity_ref =
false)
332 const auto is_infinity_error = [=](
auto o,
auto r) {
333 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
334 const bool both_infinite_and_same =
335 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
337 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
345 for(std::size_t i = 0; i < ref.size(); ++i)
347 const double o = type_convert<float>(*std::next(std::begin(out), i));
348 const double r = type_convert<float>(*std::next(std::begin(ref), i));
349 err = std::abs(o - r);
350 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
352 max_err = err > max_err ? err : max_err;
356 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
357 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
386 template <
typename Range,
typename RefRange>
388 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
389 std::is_same_v<ranges::range_value_t<Range>,
half_t>,
393 const std::string& msg =
"Error: Incorrect results!",
396 bool allow_infinity_ref =
false)
401 const auto is_infinity_error = [=](
auto o,
auto r) {
402 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
403 const bool both_infinite_and_same =
404 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
406 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
412 double max_err =
static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>
::min());
413 for(std::size_t i = 0; i < ref.size(); ++i)
415 const double o = type_convert<float>(*std::next(std::begin(out), i));
416 const double r = type_convert<float>(*std::next(std::begin(ref), i));
417 err = std::abs(o - r);
418 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
420 max_err = err > max_err ? err : max_err;
424 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
425 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
452 template <
typename Range,
typename RefRange>
453 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
454 std::is_integral_v<ranges::range_value_t<Range>> &&
455 !std::is_same_v<ranges::range_value_t<Range>,
bf16_t>)
456 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
463 const std::string& msg =
"Error: Incorrect results!",
474 for(std::size_t i = 0; i < ref.size(); ++i)
476 const int64_t o = *std::next(std::begin(out), i);
477 const int64_t r = *std::next(std::begin(ref), i);
478 err = std::abs(o - r);
482 max_err = err > max_err ? err : max_err;
486 std::cerr << msg <<
" out[" << i <<
"] != ref[" << i <<
"]: " << o <<
" != " << r
516 template <
typename Range,
typename RefRange>
517 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
518 std::is_same_v<ranges::range_value_t<Range>,
fp8_t>),
522 const std::string& msg =
"Error: Incorrect results!",
523 unsigned max_rounding_point_distance = 1,
525 bool allow_infinity_ref =
false)
530 const auto is_infinity_error = [=](
auto o,
auto r) {
531 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
532 const bool both_infinite_and_same =
533 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
535 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
538 static const auto get_rounding_point_distance = [](
fp8_t o,
fp8_t r) ->
unsigned {
539 static const auto get_sign_bit = [](
fp8_t v) ->
bool {
540 return 0x80 & bit_cast<uint8_t>(v);
543 if(get_sign_bit(o) ^ get_sign_bit(r))
549 return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
557 for(std::size_t i = 0; i < ref.size(); ++i)
559 const fp8_t o_fp8 = *std::next(std::begin(out), i);
560 const fp8_t r_fp8 = *std::next(std::begin(ref), i);
561 const double o_fp64 = type_convert<float>(o_fp8);
562 const double r_fp64 = type_convert<float>(r_fp8);
563 err = std::abs(o_fp64 - r_fp64);
565 get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
566 is_infinity_error(o_fp64, r_fp64))
568 max_err = err > max_err ? err : max_err;
572 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
573 <<
"] != ref[" << i <<
"]: " << o_fp64 <<
" != " << r_fp64 << std::endl;
601 template <
typename Range,
typename RefRange>
602 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
603 std::is_same_v<ranges::range_value_t<Range>,
bf8_t>),
607 const std::string& msg =
"Error: Incorrect results!",
610 bool allow_infinity_ref =
false)
615 const auto is_infinity_error = [=](
auto o,
auto r) {
616 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
617 const bool both_infinite_and_same =
618 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
620 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
627 for(std::size_t i = 0; i < ref.size(); ++i)
629 const double o = type_convert<float>(*std::next(std::begin(out), i));
630 const double r = type_convert<float>(*std::next(std::begin(ref), i));
631 err = std::abs(o - r);
632 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
634 max_err = err > max_err ? err : max_err;
638 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
639 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
#define CK_TILE_HOST
Definition: config.hpp:40
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:33
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:31
_BitInt(8) fp8_t
Definition: float8.hpp:204
CK_TILE_HOST bool check_size_mismatch(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!")
Check for size mismatch between output and reference ranges.
Definition: check_err.hpp:201
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations=1)
Calculate relative error threshold for numerical comparisons.
Definition: check_err.hpp:52
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
Report error statistics for numerical comparisons.
Definition: check_err.hpp:223
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Calculate absolute error threshold for numerical comparisons.
Definition: check_err.hpp:112
std::ostream & operator<<(std::ostream &os, const std::vector< T > &v)
Stream operator overload for vector output.
Definition: check_err.hpp:172
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type CK_TILE_HOST check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6, bool allow_infinity_ref=false)
Check errors between floating point ranges using the specified tolerances.
Definition: check_err.hpp:254
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:29
int32_t int32_t
Definition: integer.hpp:10
ck_tile::bf8_t BF8
8-bit brain floating point type
Definition: check_err.hpp:27
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:37
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:25
constexpr int ERROR_DETAIL_LIMIT
Maximum number of error values to display when checking errors.
Definition: check_err.hpp:22
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
_Float16 half_t
Definition: half.hpp:111
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:35
_BitInt(4) int4_t
Definition: data_type.hpp:31
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
constexpr bool is_same_v
Definition: type.hpp:283
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
signed __int64 int64_t
Definition: stdint.h:135
Definition: type_traits.hpp:115
Definition: numeric.hpp:81