13 #include <type_traits>
26 template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
37 static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
38 is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
39 is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
40 is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
41 "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
42 double compute_error = 0;
43 if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
44 is_same_v<ComputeDataType, int>)
53 static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
54 is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
55 is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
56 is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
57 "Warning: Unhandled OutDataType for setting up the relative threshold!");
58 double output_error = 0;
59 if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
60 is_same_v<OutDataType, int>)
68 double midway_error =
std::max(compute_error, output_error);
70 static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
71 is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
72 is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
73 is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
74 "Warning: Unhandled AccDataType for setting up the relative threshold!");
76 if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
77 is_same_v<AccDataType, int>)
85 return std::max(acc_error, midway_error);
88 template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
99 static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
100 is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
101 is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
102 is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
103 "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
104 auto expo = std::log2(std::abs(max_possible_num));
105 double compute_error = 0;
106 if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
107 is_same_v<ComputeDataType, int>)
116 static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
117 is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
118 is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
119 is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
120 "Warning: Unhandled OutDataType for setting up the absolute threshold!");
121 double output_error = 0;
122 if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
123 is_same_v<OutDataType, int>)
131 double midway_error =
std::max(compute_error, output_error);
133 static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
134 is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
135 is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
136 is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
137 "Warning: Unhandled AccDataType for setting up the absolute threshold!");
138 double acc_error = 0;
139 if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
140 is_same_v<AccDataType, int>)
149 return std::max(acc_error, midway_error);
152 template <
typename Range,
typename RefRange>
155 std::is_floating_point_v<ranges::range_value_t<Range>> &&
156 !std::is_same_v<ranges::range_value_t<Range>,
half_t>,
160 const std::string& msg =
"Error: Incorrect results!",
164 if(out.size() != ref.size())
166 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
175 for(std::size_t i = 0; i < ref.size(); ++i)
177 const double o = *std::next(std::begin(out), i);
178 const double r = *std::next(std::begin(ref), i);
179 err = std::abs(o - r);
180 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
182 max_err = err > max_err ? err : max_err;
186 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
187 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
194 const float error_percent =
195 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
196 std::cerr <<
"max err: " << max_err;
197 std::cerr <<
", number of errors: " << err_count;
198 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
203 template <
typename Range,
typename RefRange>
206 std::is_same_v<ranges::range_value_t<Range>,
bhalf_t>,
210 const std::string& msg =
"Error: Incorrect results!",
214 if(out.size() != ref.size())
216 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
226 for(std::size_t i = 0; i < ref.size(); ++i)
228 const double o = type_convert<float>(*std::next(std::begin(out), i));
229 const double r = type_convert<float>(*std::next(std::begin(ref), i));
230 err = std::abs(o - r);
231 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
233 max_err = err > max_err ? err : max_err;
237 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
238 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
245 const float error_percent =
246 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
247 std::cerr <<
"max err: " << max_err;
248 std::cerr <<
", number of errors: " << err_count;
249 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
254 template <
typename Range,
typename RefRange>
257 std::is_same_v<ranges::range_value_t<Range>,
half_t>,
261 const std::string& msg =
"Error: Incorrect results!",
265 if(out.size() != ref.size())
267 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
276 for(std::size_t i = 0; i < ref.size(); ++i)
278 const double o = type_convert<float>(*std::next(std::begin(out), i));
279 const double r = type_convert<float>(*std::next(std::begin(ref), i));
280 err = std::abs(o - r);
281 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
283 max_err = err > max_err ? err : max_err;
287 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
288 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
295 const float error_percent =
296 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
297 std::cerr <<
"max err: " << max_err;
298 std::cerr <<
", number of errors: " << err_count;
299 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
304 template <
typename Range,
typename RefRange>
306 std::is_integral_v<ranges::range_value_t<Range>> &&
307 !std::is_same_v<ranges::range_value_t<Range>,
bhalf_t> &&
308 !std::is_same_v<ranges::range_value_t<Range>,
f8_t> &&
309 !std::is_same_v<ranges::range_value_t<Range>,
bf8_t>)
310 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
317 const std::string& msg =
"Error: Incorrect results!",
321 if(out.size() != ref.size())
323 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
332 for(std::size_t i = 0; i < ref.size(); ++i)
334 const int64_t o = *std::next(std::begin(out), i);
335 const int64_t r = *std::next(std::begin(ref), i);
336 err = std::abs(o - r);
340 max_err = err > max_err ? err : max_err;
344 std::cerr << msg <<
" out[" << i <<
"] != ref[" << i <<
"]: " << o <<
" != " << r
352 const float error_percent =
353 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
354 std::cerr <<
"max err: " << max_err;
355 std::cerr <<
", number of errors: " << err_count;
356 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
361 template <
typename Range,
typename RefRange>
363 std::is_same_v<ranges::range_value_t<Range>,
f8_t>),
367 const std::string& msg =
"Error: Incorrect results!",
371 if(out.size() != ref.size())
373 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
383 for(std::size_t i = 0; i < ref.size(); ++i)
385 const double o = type_convert<float>(*std::next(std::begin(out), i));
386 const double r = type_convert<float>(*std::next(std::begin(ref), i));
387 err = std::abs(o - r);
389 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
391 max_err = err > max_err ? err : max_err;
395 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
396 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
404 std::cerr << std::setw(12) << std::setprecision(7) <<
"max err: " << max_err
405 <<
" number of errors: " << err_count << std::endl;
410 template <
typename Range,
typename RefRange>
412 std::is_same_v<ranges::range_value_t<Range>,
bf8_t>),
416 const std::string& msg =
"Error: Incorrect results!",
420 if(out.size() != ref.size())
422 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
431 for(std::size_t i = 0; i < ref.size(); ++i)
433 const double o = type_convert<float>(*std::next(std::begin(out), i));
434 const double r = type_convert<float>(*std::next(std::begin(ref), i));
435 err = std::abs(o - r);
436 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
438 max_err = err > max_err ? err : max_err;
442 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
443 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
450 std::cerr << std::setw(12) << std::setprecision(7) <<
"max err: " << max_err << std::endl;
455 template <
typename Range,
typename RefRange>
457 std::is_same_v<ranges::range_value_t<Range>,
f4_t>),
461 const std::string& msg =
"Error: Incorrect results!",
465 if(out.size() != ref.size())
467 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
477 for(std::size_t i = 0; i < ref.size(); ++i)
479 const double o = type_convert<float>(*std::next(std::begin(out), i));
480 const double r = type_convert<float>(*std::next(std::begin(ref), i));
481 err = std::abs(o - r);
483 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
485 max_err = err > max_err ? err : max_err;
489 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
490 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
498 std::cerr << std::setw(12) << std::setprecision(7) <<
"max err: " << max_err
499 <<
" number of errors: " << err_count << std::endl;
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
iter_value_t< ranges::iterator_t< R > > range_value_t
Definition: ranges.hpp:28
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6)
Definition: check_err.hpp:158
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Definition: check_err.hpp:89
double get_relative_threshold(const int number_of_accumulations=1)
Definition: check_err.hpp:27
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:33
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:31
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:29
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:37
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:25
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:35
bf8_fnuz_t bf8_t
Definition: amd_ck_fp8.hpp:1738
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
unsigned _BitInt(4) f4_t
Definition: data_type.hpp:32
_Float16 half_t
Definition: data_type.hpp:30
ushort bhalf_t
Definition: data_type.hpp:29
_BitInt(4) int4_t
Definition: data_type.hpp:31
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
constexpr bool is_same_v
Definition: type.hpp:283
long int64_t
Definition: data_type.hpp:461
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:309
Definition: numeric_utils.hpp:10