13 #include <type_traits>
// NOTE(review): partial listing — the function signature, the conditions guarding the
// "Unhandled ..." diagnostics and the per-type error formulas are not visible here.
// The cross-reference entries at the end of this file name this function
// get_relative_threshold(const int number_of_accumulations = 1), returning double.
//
// Visible structure: one error term is prepared for each of ComputeDataType, OutDataType
// and AccDataType (AccDataType defaults to ComputeDataType), and the function returns the
// largest of the three.
21 template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
33 "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
// Error term for the compute type; starts at 0 (how it is filled in is not visible).
35 double compute_error = 0;
46 "Warning: Unhandled OutDataType for setting up the relative threshold!");
// Error term for the output type; starts at 0 (how it is filled in is not visible).
48 double output_error = 0;
// Larger of the compute-type and output-type error terms.
57 double midway_error =
std::max(compute_error, output_error);
60 "Warning: Unhandled AccDataType for setting up the relative threshold!");
// Result: the larger of the accumulator-type error term and midway_error.
// (The declaration of acc_error is on a line missing from this listing.)
71 return std::max(acc_error, midway_error);
// NOTE(review): partial listing — the function signature, the conditions guarding the
// "Unhandled ..." diagnostics and the per-type error formulas are not visible here.
// The cross-reference entries at the end of this file name this function
// get_absolute_threshold(const double max_possible_num,
//                        const int number_of_accumulations = 1), returning double.
//
// Visible structure: mirrors the relative-threshold helper — one error term per
// ComputeDataType / OutDataType / AccDataType, result is the largest of the three — with
// the addition of a base-2 exponent derived from max_possible_num.
74 template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
86 "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
// Base-2 logarithm of |max_possible_num|; presumably used to scale the per-type error
// terms below — the lines that consume it are not visible, TODO confirm.
88 auto expo = std::log2(std::abs(max_possible_num));
89 double compute_error = 0;
100 "Warning: Unhandled OutDataType for setting up the absolute threshold!");
102 double output_error = 0;
// Larger of the compute-type and output-type error terms.
111 double midway_error =
std::max(compute_error, output_error);
114 "Warning: Unhandled AccDataType for setting up the absolute threshold!");
116 double acc_error = 0;
// Result: the larger of the accumulator-type error term and midway_error.
126 return std::max(acc_error, midway_error);
// Stream-insertion operator for std::vector<T>: takes the output stream and a const
// reference to the vector, and returns std::ostream&.
// NOTE(review): partial listing — the loop body (how each element and any separators are
// written) and the return statement are not visible here.
129 template <
typename T>
130 std::ostream&
operator<<(std::ostream& os,
const std::vector<T>& v)
// Index type matching the vector's own size_type, so the loop comparison below is not a
// signed/unsigned mix.
132 using size_type =
typename std::vector<T>::size_type;
// Visits every element position of v in order.
135 for(size_type idx = 0; idx < v.size(); ++idx)
// check_err overload for ranges of a standard floating-point value type other than half_t.
// Selected only when `out` and `ref` have the same value type (see the enable_if condition
// below). Per the cross-reference entry at the end of this file the full signature is
//   check_err(const Range& out, const RefRange& ref,
//             const std::string& msg = "Error: Incorrect results!",
//             double rtol = 1e-5, double atol = 3e-6, bool allow_infinity_ref = false)
// and the return type is bool.
// NOTE(review): partial listing — braces, the bodies of the `if` statements, the
// declarations of err / max_err / err_count and the return statements are not visible.
146 template <
typename Range,
typename RefRange>
148 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
149 std::is_floating_point_v<ranges::range_value_t<Range>> &&
150 !std::is_same_v<ranges::range_value_t<Range>,
half_t>,
154 const std::string& msg =
"Error: Incorrect results!",
157 bool allow_infinity_ref =
false)
// Size mismatch between the two ranges is reported on std::cerr, prefixed with `msg`.
159 if(out.size() != ref.size())
161 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
// Flags a pair as an error when either value is NaN/inf, unless allow_infinity_ref is set
// and both values are infinities with identical bit patterns (i.e. the same sign).
166 const auto is_infinity_error = [=](
auto o,
auto r) {
167 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
168 const bool both_infinite_and_same =
// The 64-bit cast matches the call site below, which passes both values as double.
169 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
171 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
// Element-wise comparison; both values are widened to double before comparing.
178 for(std::size_t i = 0; i < ref.size(); ++i)
180 const double o = *std::next(std::begin(out), i);
181 const double r = *std::next(std::begin(ref), i);
182 err = std::abs(o - r);
// Mismatch when the absolute difference exceeds atol + rtol * |ref|, or when the
// non-finite check above fires.
183 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
// Track the largest absolute difference among mismatching elements.
185 max_err = err > max_err ? err : max_err;
// Per-element report (any condition limiting how many are printed is not visible here).
189 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
190 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
// Summary: share of mismatching elements as a percentage of out.size().
197 const float error_percent =
198 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
199 std::cerr <<
"max err: " << max_err;
200 std::cerr <<
", number of errors: " << err_count;
201 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
// Comparison routine for ranges whose value type is bf16_t (both ranges must share that
// value type, see the condition below). Presumably another check_err overload — the
// declaration line, including the rtol/atol parameters and their defaults, is not visible
// in this listing; TODO confirm.
// NOTE(review): partial listing — braces, `if` bodies, the declarations of
// err / max_err / err_count and the return statements are not visible.
206 template <
typename Range,
typename RefRange>
208 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
209 std::is_same_v<ranges::range_value_t<Range>,
bf16_t>,
213 const std::string& msg =
"Error: Incorrect results!",
216 bool allow_infinity_ref =
false)
// Size mismatch between the two ranges is reported on std::cerr, prefixed with `msg`.
218 if(out.size() != ref.size())
220 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
// Flags a pair as an error when either value is NaN/inf, unless allow_infinity_ref is set
// and both values are infinities with identical bit patterns (i.e. the same sign).
225 const auto is_infinity_error = [=](
auto o,
auto r) {
226 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
227 const bool both_infinite_and_same =
// The 64-bit cast matches the call site below, which passes both values as double.
228 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
230 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
// Element-wise comparison; each bf16 value is converted to float via type_convert and
// then held as double.
238 for(std::size_t i = 0; i < ref.size(); ++i)
240 const double o = type_convert<float>(*std::next(std::begin(out), i));
241 const double r = type_convert<float>(*std::next(std::begin(ref), i));
242 err = std::abs(o - r);
// Mismatch when the absolute difference exceeds atol + rtol * |ref|, or when the
// non-finite check above fires.
243 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
// Track the largest absolute difference among mismatching elements.
245 max_err = err > max_err ? err : max_err;
249 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
250 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
// Summary: share of mismatching elements as a percentage of out.size().
257 const float error_percent =
258 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
259 std::cerr <<
"max err: " << max_err;
260 std::cerr <<
", number of errors: " << err_count;
261 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
// Comparison routine for ranges whose value type is half_t (both ranges must share that
// value type, see the condition below). Presumably another check_err overload — the
// declaration line, including the rtol/atol parameters and their defaults, is not visible
// in this listing; TODO confirm.
// NOTE(review): partial listing — braces, `if` bodies, the declarations of
// err / err_count and the return statements are not visible.
266 template <
typename Range,
typename RefRange>
268 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
269 std::is_same_v<ranges::range_value_t<Range>,
half_t>,
273 const std::string& msg =
"Error: Incorrect results!",
276 bool allow_infinity_ref =
false)
// Size mismatch between the two ranges is reported on std::cerr, prefixed with `msg`.
278 if(out.size() != ref.size())
280 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
// Flags a pair as an error when either value is NaN/inf, unless allow_infinity_ref is set
// and both values are infinities with identical bit patterns (i.e. the same sign).
285 const auto is_infinity_error = [=](
auto o,
auto r) {
286 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
287 const bool both_infinite_and_same =
// The 64-bit cast matches the call site below, which passes both values as double.
288 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
290 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
// Running maximum starts at numeric_limits<value_type>::min() converted to double.
// NOTE(review): for standard floating-point types min() is the smallest positive
// normalized value, not the most negative one — presumably the same holds for the half_t
// specialization; confirm.
296 double max_err =
static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>
::min());
// Element-wise comparison; each half value is converted to float via type_convert and
// then held as double.
297 for(std::size_t i = 0; i < ref.size(); ++i)
299 const double o = type_convert<float>(*std::next(std::begin(out), i));
300 const double r = type_convert<float>(*std::next(std::begin(ref), i));
301 err = std::abs(o - r);
// Mismatch when the absolute difference exceeds atol + rtol * |ref|, or when the
// non-finite check above fires.
302 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
// Track the largest absolute difference among mismatching elements.
304 max_err = err > max_err ? err : max_err;
308 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
309 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
// Summary: share of mismatching elements as a percentage of out.size().
316 const float error_percent =
317 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
318 std::cerr <<
"max err: " << max_err;
319 std::cerr <<
", number of errors: " << err_count;
320 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
// Comparison routine for ranges of an integral value type (both ranges must share the
// value type; bf16_t is explicitly excluded by the condition below). Presumably another
// check_err overload — the declaration line and any tolerance parameters are not visible
// in this listing; TODO confirm.
// NOTE(review): partial listing — the remainder of the enable_if condition (including the
// CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 branch), the mismatch condition, braces, the
// declarations of err / max_err / err_count and the return statements are not visible.
325 template <
typename Range,
typename RefRange>
326 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
327 std::is_integral_v<ranges::range_value_t<Range>> &&
328 !std::is_same_v<ranges::range_value_t<Range>,
bf16_t>)
// Extra condition compiled in only when the experimental int4 extension macro is defined;
// its content is on lines missing from this listing.
329 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
336 const std::string& msg =
"Error: Incorrect results!",
// Size mismatch between the two ranges is reported on std::cerr, prefixed with `msg`.
340 if(out.size() != ref.size())
342 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
// Element-wise comparison; both values are widened to int64_t before subtracting.
351 for(std::size_t i = 0; i < ref.size(); ++i)
353 const int64_t o = *std::next(std::begin(out), i);
354 const int64_t r = *std::next(std::begin(ref), i);
355 err = std::abs(o - r);
// Track the largest absolute difference (the condition guarding this update is not
// visible here).
359 max_err = err > max_err ? err : max_err;
363 std::cerr << msg <<
" out[" << i <<
"] != ref[" << i <<
"]: " << o <<
" != " << r
// Summary: share of mismatching elements as a percentage of out.size().
371 const float error_percent =
372 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
373 std::cerr <<
"max err: " << max_err;
374 std::cerr <<
", number of errors: " << err_count;
375 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
// Comparison routine for ranges whose value type is fp8_t (both ranges must share that
// value type). Presumably another check_err overload — the declaration line and any
// rtol/atol parameters are not visible in this listing; TODO confirm.
// In addition to the parameters seen in the sibling routines, this one takes
// max_rounding_point_distance (default 1), which is compared against the raw-encoding
// distance computed by get_rounding_point_distance below.
// NOTE(review): partial listing — braces, the first line of the mismatch condition, the
// differing-sign branch of get_rounding_point_distance, the declarations of
// err / max_err / err_count and the return statements are not visible.
380 template <
typename Range,
typename RefRange>
381 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
382 std::is_same_v<ranges::range_value_t<Range>,
fp8_t>),
386 const std::string& msg =
"Error: Incorrect results!",
387 unsigned max_rounding_point_distance = 1,
389 bool allow_infinity_ref =
false)
// Size mismatch between the two ranges is reported on std::cerr, prefixed with `msg`.
391 if(out.size() != ref.size())
393 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
// Flags a pair as an error when either value is NaN/inf, unless allow_infinity_ref is set
// and both values are infinities with identical bit patterns (i.e. the same sign).
398 const auto is_infinity_error = [=](
auto o,
auto r) {
399 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
400 const bool both_infinite_and_same =
// The 64-bit cast matches the call site below, which passes both values as double.
401 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
403 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
// Distance between two fp8 values measured on their raw 8-bit encodings.
406 static const auto get_rounding_point_distance = [](
fp8_t o,
fp8_t r) ->
unsigned {
// Sign test: bit 7 (mask 0x80) of the raw byte.
407 static const auto get_sign_bit = [](
fp8_t v) ->
bool {
408 return 0x80 & bit_cast<uint8_t>(v);
// Operands with different sign bits are handled separately (that branch's body is not
// visible here).
411 if(get_sign_bit(o) ^ get_sign_bit(r))
// Same sign: absolute difference of the raw bytes reinterpreted as int8_t.
417 return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
// Element-wise comparison; the fp8 values are kept for the encoding-distance check and
// also converted (via float) to double for the numeric difference and for printing.
425 for(std::size_t i = 0; i < ref.size(); ++i)
427 const fp8_t o_fp8 = *std::next(std::begin(out), i);
428 const fp8_t r_fp8 = *std::next(std::begin(ref), i);
429 const double o_fp64 = type_convert<float>(o_fp8);
430 const double r_fp64 = type_convert<float>(r_fp8);
431 err = std::abs(o_fp64 - r_fp64);
// Tail of the mismatch condition: an encoding-distance test against
// max_rounding_point_distance, OR-ed with the non-finite check. The opening part of the
// condition is on a line missing from this listing, so the exact polarity cannot be
// stated from here.
433 get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
434 is_infinity_error(o_fp64, r_fp64))
// Track the largest absolute difference among mismatching elements.
436 max_err = err > max_err ? err : max_err;
440 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
441 <<
"] != ref[" << i <<
"]: " << o_fp64 <<
" != " << r_fp64 << std::endl;
// Summary: share of mismatching elements as a percentage of out.size().
448 const float error_percent =
449 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
450 std::cerr <<
"max err: " << max_err;
451 std::cerr <<
", number of errors: " << err_count;
452 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
// Comparison routine for ranges whose value type is bf8_t (both ranges must share that
// value type, see the condition below). Presumably another check_err overload — the
// declaration line, including the rtol/atol parameters and their defaults, is not visible
// in this listing; TODO confirm.
// NOTE(review): partial listing — braces, `if` bodies, the declarations of
// err / max_err / err_count and the return statements are not visible.
457 template <
typename Range,
typename RefRange>
458 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
459 std::is_same_v<ranges::range_value_t<Range>,
bf8_t>),
463 const std::string& msg =
"Error: Incorrect results!",
466 bool allow_infinity_ref =
false)
// Size mismatch between the two ranges is reported on std::cerr, prefixed with `msg`.
468 if(out.size() != ref.size())
470 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
// Flags a pair as an error when either value is NaN/inf, unless allow_infinity_ref is set
// and both values are infinities with identical bit patterns (i.e. the same sign).
475 const auto is_infinity_error = [=](
auto o,
auto r) {
476 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
477 const bool both_infinite_and_same =
// The 64-bit cast matches the call site below, which passes both values as double.
478 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
480 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
// Element-wise comparison; each bf8 value is converted to float via type_convert and
// then held as double.
487 for(std::size_t i = 0; i < ref.size(); ++i)
489 const double o = type_convert<float>(*std::next(std::begin(out), i));
490 const double r = type_convert<float>(*std::next(std::begin(ref), i));
491 err = std::abs(o - r);
// Mismatch when the absolute difference exceeds atol + rtol * |ref|, or when the
// non-finite check above fires.
492 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
// Track the largest absolute difference among mismatching elements.
494 max_err = err > max_err ? err : max_err;
498 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
499 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
// Summary: share of mismatching elements as a percentage of out.size().
506 const float error_percent =
507 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
508 std::cerr <<
"max err: " << max_err;
509 std::cerr <<
", number of errors: " << err_count;
510 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
#define CK_TILE_HOST
Definition: config.hpp:39
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
_BitInt(8) fp8_t
Definition: float8.hpp:204
int8_t int8_t
Definition: int8.hpp:20
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Definition: check_err.hpp:75
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
std::ostream & operator<<(std::ostream &os, const std::vector< T > &v)
Definition: check_err.hpp:130
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type CK_TILE_HOST check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6, bool allow_infinity_ref=false)
Definition: check_err.hpp:152
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
double get_relative_threshold(const int number_of_accumulations=1)
Definition: check_err.hpp:22
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
_Float16 half_t
Definition: half.hpp:111
_BitInt(4) int4_t
Definition: data_type.hpp:26
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:10
constexpr bool is_same_v
Definition: type.hpp:283
long int64_t
Definition: data_type.hpp:2474
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:13
Definition: type_traits.hpp:114
Definition: bfloat16.hpp:380