13 #include <type_traits>
51 template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
67 "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
69 double compute_error = 0;
79 static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
80 "Warning: Unhandled OutDataType for setting up the relative threshold!");
82 double output_error = 0;
91 double midway_error =
std::max(compute_error, output_error);
93 static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
94 "Warning: Unhandled AccDataType for setting up the relative threshold!");
105 return std::max(acc_error, midway_error);
121 template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
123 const int number_of_accumulations = 1)
138 "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
140 auto expo = std::log2(std::abs(max_possible_num));
141 double compute_error = 0;
151 static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
152 "Warning: Unhandled OutDataType for setting up the absolute threshold!");
154 double output_error = 0;
163 double midway_error =
std::max(compute_error, output_error);
165 static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
166 "Warning: Unhandled AccDataType for setting up the absolute threshold!");
168 double acc_error = 0;
178 return std::max(acc_error, midway_error);
191 template <
typename T>
192 std::ostream&
operator<<(std::ostream& os,
const std::vector<T>& v)
194 using size_type =
typename std::vector<T>::size_type;
197 for(size_type idx = 0; idx < v.size(); ++idx)
220 template <
typename Range,
typename RefRange>
223 const std::string& msg =
"Error: Incorrect results!")
225 if(out.size() != ref.size())
227 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
245 const float error_percent =
246 static_cast<float>(err_count) /
static_cast<float>(total_size) * 100.f;
247 std::cerr <<
"max err: " << max_err;
248 std::cerr <<
", number of errors: " << err_count;
249 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
268 template <
typename Range,
typename RefRange>
270 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
271 std::is_floating_point_v<ranges::range_value_t<Range>> &&
272 !std::is_same_v<ranges::range_value_t<Range>,
half_t>,
276 const std::string& msg =
"Error: Incorrect results!",
279 bool allow_infinity_ref =
false)
285 const auto is_infinity_error = [=](
auto o,
auto r) {
286 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
287 const bool both_infinite_and_same =
288 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
290 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
297 for(std::size_t i = 0; i < ref.size(); ++i)
299 const double o = *std::next(std::begin(out), i);
300 const double r = *std::next(std::begin(ref), i);
301 err = std::abs(o - r);
302 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
304 max_err = err > max_err ? err : max_err;
308 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
309 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
337 template <
typename Range,
typename RefRange>
339 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
340 std::is_same_v<ranges::range_value_t<Range>,
bf16_t>,
344 const std::string& msg =
"Error: Incorrect results!",
347 bool allow_infinity_ref =
false)
352 const auto is_infinity_error = [=](
auto o,
auto r) {
353 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
354 const bool both_infinite_and_same =
355 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
357 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
365 for(std::size_t i = 0; i < ref.size(); ++i)
367 const double o = type_convert<float>(*std::next(std::begin(out), i));
368 const double r = type_convert<float>(*std::next(std::begin(ref), i));
369 err = std::abs(o - r);
370 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
372 max_err = err > max_err ? err : max_err;
376 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
377 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
406 template <
typename Range,
typename RefRange>
408 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
409 std::is_same_v<ranges::range_value_t<Range>,
half_t>,
413 const std::string& msg =
"Error: Incorrect results!",
416 bool allow_infinity_ref =
false)
421 const auto is_infinity_error = [=](
auto o,
auto r) {
422 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
423 const bool both_infinite_and_same =
424 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
426 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
432 double max_err =
static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>
::min());
433 for(std::size_t i = 0; i < ref.size(); ++i)
435 const double o = type_convert<float>(*std::next(std::begin(out), i));
436 const double r = type_convert<float>(*std::next(std::begin(ref), i));
437 err = std::abs(o - r);
438 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
440 max_err = err > max_err ? err : max_err;
444 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
445 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
472 template <
typename Range,
typename RefRange>
473 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
474 std::is_integral_v<ranges::range_value_t<Range>> &&
475 !std::is_same_v<ranges::range_value_t<Range>,
bf16_t>)
476 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
483 const std::string& msg =
"Error: Incorrect results!",
494 for(std::size_t i = 0; i < ref.size(); ++i)
496 const int64_t o = *std::next(std::begin(out), i);
497 const int64_t r = *std::next(std::begin(ref), i);
498 err = std::abs(o - r);
502 max_err = err > max_err ? err : max_err;
506 std::cerr << msg <<
" out[" << i <<
"] != ref[" << i <<
"]: " << o <<
" != " << r
536 template <
typename Range,
typename RefRange>
537 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
538 std::is_same_v<ranges::range_value_t<Range>,
fp8_t>),
542 const std::string& msg =
"Error: Incorrect results!",
543 unsigned max_rounding_point_distance = 1,
545 bool allow_infinity_ref =
false)
550 const auto is_infinity_error = [=](
auto o,
auto r) {
551 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
552 const bool both_infinite_and_same =
553 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
555 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
558 static const auto get_rounding_point_distance = [](
fp8_t o,
fp8_t r) ->
unsigned {
559 static const auto get_sign_bit = [](
fp8_t v) ->
bool {
560 return 0x80 & bit_cast<uint8_t>(v);
563 if(get_sign_bit(o) ^ get_sign_bit(r))
569 return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
577 for(std::size_t i = 0; i < ref.size(); ++i)
579 const fp8_t o_fp8 = *std::next(std::begin(out), i);
580 const fp8_t r_fp8 = *std::next(std::begin(ref), i);
581 const double o_fp64 = type_convert<float>(o_fp8);
582 const double r_fp64 = type_convert<float>(r_fp8);
583 err = std::abs(o_fp64 - r_fp64);
585 get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
586 is_infinity_error(o_fp64, r_fp64))
588 max_err = err > max_err ? err : max_err;
592 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
593 <<
"] != ref[" << i <<
"]: " << o_fp64 <<
" != " << r_fp64 << std::endl;
621 template <
typename Range,
typename RefRange>
622 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
623 std::is_same_v<ranges::range_value_t<Range>,
bf8_t>),
627 const std::string& msg =
"Error: Incorrect results!",
630 bool allow_infinity_ref =
false)
635 const auto is_infinity_error = [=](
auto o,
auto r) {
636 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
637 const bool both_infinite_and_same =
638 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
640 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
647 for(std::size_t i = 0; i < ref.size(); ++i)
649 const double o = type_convert<float>(*std::next(std::begin(out), i));
650 const double r = type_convert<float>(*std::next(std::begin(ref), i));
651 err = std::abs(o - r);
652 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
654 max_err = err > max_err ? err : max_err;
658 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
659 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
684 template <
typename Range,
typename RefRange>
685 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
686 std::is_same_v<ranges::range_value_t<Range>,
pk_fp4_t>),
690 const std::string& msg =
"Error: Incorrect results!",
702 std::cerr << msg <<
" out[" << index <<
"] != ref[" << index
703 <<
"]: " << type_convert<float>(
pk_fp4_t{o})
704 <<
" != " << type_convert<float>(
pk_fp4_t{r}) << std::endl;
709 for(std::size_t i = 0; i < ref.size(); ++i)
711 const pk_fp4_t o = *std::next(std::begin(out), i);
712 const pk_fp4_t r = *std::next(std::begin(ref), i);
720 return err_count == 0;
#define CK_TILE_HOST
Definition: config.hpp:44
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:33
typename pk_fp4_t::type pk_fp4_raw_t
Definition: pk_fp4.hpp:152
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:31
_BitInt(8) fp8_t
Definition: float8.hpp:204
CK_TILE_HOST bool check_size_mismatch(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!")
Check for size mismatch between output and reference ranges.
Definition: check_err.hpp:221
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations=1)
Calculate relative error threshold for numerical comparisons.
Definition: check_err.hpp:52
pk_float4_e2m1_t pk_fp4_t
Definition: pk_fp4.hpp:151
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
Report error statistics for numerical comparisons.
Definition: check_err.hpp:243
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Calculate absolute error threshold for numerical comparisons.
Definition: check_err.hpp:122
std::ostream & operator<<(std::ostream &os, const std::vector< T > &v)
Stream operator overload for vector output.
Definition: check_err.hpp:192
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type CK_TILE_HOST check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6, bool allow_infinity_ref=false)
Check errors between floating point ranges using the specified tolerances.
Definition: check_err.hpp:274
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:29
int32_t int32_t
Definition: integer.hpp:10
ck_tile::bf8_t BF8
8-bit brain floating point type
Definition: check_err.hpp:27
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:37
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:25
constexpr int ERROR_DETAIL_LIMIT
Maximum number of error values to display when checking errors.
Definition: check_err.hpp:22
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:206
_Float16 half_t
Definition: half.hpp:111
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:35
_BitInt(4) int4_t
Definition: data_type.hpp:32
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
constexpr bool is_same_v
Definition: type.hpp:283
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1697
signed __int64 int64_t
Definition: stdint.h:135
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:115
Definition: numeric.hpp:81
Definition: numeric.hpp:18
Definition: pk_fp4.hpp:76
constexpr CK_TILE_HOST_DEVICE type _unpack(number< I >) const
Definition: pk_int4.hpp:21