/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/check_err.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/check_err.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/check_err.hpp Source File
check_err.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <algorithm>
7 #include <cmath>
8 #include <cstdlib>
9 #include <iostream>
10 #include <iomanip>
11 #include <iterator>
12 #include <limits>
13 #include <type_traits>
14 #include <vector>
15 
16 #include "ck_tile/core.hpp"
17 #include "ck_tile/host/ranges.hpp"
18 
19 namespace ck_tile {
20 
// Maximum number of error values to display when checking errors.
22 constexpr int ERROR_DETAIL_LIMIT = 5;
23 
// 32-bit floating point (single precision) type
33 using F32 = float;
// 8-bit signed integer type
35 using I8 = int8_t;
// 32-bit signed integer type
37 using I32 = int32_t;
38 
// Calculate relative error threshold for numerical comparisons.
//
// Template parameters: ComputeDataType, OutDataType, AccDataType (defaults to ComputeDataType).
// number_of_accumulations multiplies only the accumulator term.
//
// Each per-type term is 0.5 * 2^(-numeric_traits<T>::mant). The result is the largest of the
// compute, output and (scaled) accumulator terms; the early-return branches yield 0 instead.
//
// NOTE(review): the conditions of the three static_asserts and of the three branches that
// `return 0` are not visible in this listing (listing lines 56, 60, 69, 73, 83, 87 are absent).
// Presumably they select types without a mantissa-based rounding error -- confirm against the
// actual header before relying on this description.
51 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
52 CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1)
53 {
54 
55  static_assert(
57  "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
58 
// Error term contributed by ComputeDataType.
59  double compute_error = 0;
61  {
62  return 0;
63  }
64  else
65  {
66  compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
67  }
68 
70  "Warning: Unhandled OutDataType for setting up the relative threshold!");
71 
// Error term contributed by OutDataType.
72  double output_error = 0;
74  {
75  return 0;
76  }
77  else
78  {
79  output_error = std::pow(2, -numeric_traits<OutDataType>::mant) * 0.5;
80  }
// The compute and output terms are combined by taking the larger of the two.
81  double midway_error = std::max(compute_error, output_error);
82 
84  "Warning: Unhandled AccDataType for setting up the relative threshold!");
85 
// Error term contributed by AccDataType; it grows linearly with number_of_accumulations.
86  double acc_error = 0;
88  {
89  return 0;
90  }
91  else
92  {
93  acc_error = std::pow(2, -numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
94  }
95  return std::max(acc_error, midway_error);
96 }
97 
// Calculate absolute error threshold for numerical comparisons.
//
// Template parameters: ComputeDataType, OutDataType, AccDataType (defaults to ComputeDataType).
// max_possible_num: its magnitude sets the binary exponent (log2(|max_possible_num|)) used to
//                   scale every term.
// number_of_accumulations multiplies only the accumulator term.
//
// Each per-type term is 0.5 * 2^(expo - numeric_traits<T>::mant). The result is the largest of
// the compute, output and (scaled) accumulator terms; the early-return branches yield 0 instead.
//
// NOTE(review): the conditions of the three static_asserts and of the three branches that
// `return 0` are not visible in this listing (listing lines 117, 122, 131, 135, 145, 149 are
// absent). Presumably they select types without a mantissa-based rounding error -- confirm
// against the actual header before relying on this description.
111 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
112 CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
113  const int number_of_accumulations = 1)
114 {
115 
116  static_assert(
118  "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
119 
// Binary exponent of the largest expected magnitude; shared by all three terms below.
120  auto expo = std::log2(std::abs(max_possible_num));
121  double compute_error = 0;
123  {
124  return 0;
125  }
126  else
127  {
128  compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
129  }
130 
132  "Warning: Unhandled OutDataType for setting up the absolute threshold!");
133 
// Error term contributed by OutDataType.
134  double output_error = 0;
136  {
137  return 0;
138  }
139  else
140  {
141  output_error = std::pow(2, expo - numeric_traits<OutDataType>::mant) * 0.5;
142  }
// The compute and output terms are combined by taking the larger of the two.
143  double midway_error = std::max(compute_error, output_error);
144 
146  "Warning: Unhandled AccDataType for setting up the absolute threshold!");
147 
// Error term contributed by AccDataType; it grows linearly with number_of_accumulations.
148  double acc_error = 0;
150  {
151  return 0;
152  }
153  else
154  {
155  acc_error =
156  std::pow(2, expo - numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
157  }
158  return std::max(acc_error, midway_error);
159 }
160 
/// @brief Stream operator overload for vector output.
///
/// Writes the elements of @p v to @p os as a bracketed, comma-separated list,
/// e.g. `[1, 2, 3]`; an empty vector is written as `[]`.
///
/// @param os Destination stream.
/// @param v  Vector whose elements are streamed with their own `operator<<`.
/// @return   Reference to @p os, to allow chaining.
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";

    auto it        = v.begin();
    const auto end = v.end();
    if(it != end)
    {
        // The first element carries no leading separator.
        os << *it;
        for(++it; it != end; ++it)
        {
            os << ", ";
            os << *it;
        }
    }

    os << "]";
    return os;
}
187 
200 template <typename Range, typename RefRange>
201 CK_TILE_HOST bool check_size_mismatch(const Range& out,
202  const RefRange& ref,
203  const std::string& msg = "Error: Incorrect results!")
204 {
205  if(out.size() != ref.size())
206  {
207  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
208  << std::endl;
209  return true;
210  }
211  return false;
212 }
213 
223 CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
224 {
225  const float error_percent =
226  static_cast<float>(err_count) / static_cast<float>(total_size) * 100.f;
227  std::cerr << "max err: " << max_err;
228  std::cerr << ", number of errors: " << err_count;
229  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
230 }
231 
248 template <typename Range, typename RefRange>
249 typename std::enable_if<
250  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
251  std::is_floating_point_v<ranges::range_value_t<Range>> &&
252  !std::is_same_v<ranges::range_value_t<Range>, half_t>,
253  bool>::type CK_TILE_HOST
254 check_err(const Range& out,
255  const RefRange& ref,
256  const std::string& msg = "Error: Incorrect results!",
257  double rtol = 1e-5,
258  double atol = 3e-6,
259  bool allow_infinity_ref = false)
260 {
261 
262  if(check_size_mismatch(out, ref, msg))
263  return false;
264 
265  const auto is_infinity_error = [=](auto o, auto r) {
266  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
267  const bool both_infinite_and_same =
268  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
269 
270  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
271  };
272 
273  bool res{true};
274  int err_count = 0;
275  double err = 0;
276  double max_err = std::numeric_limits<double>::min();
277  for(std::size_t i = 0; i < ref.size(); ++i)
278  {
279  const double o = *std::next(std::begin(out), i);
280  const double r = *std::next(std::begin(ref), i);
281  err = std::abs(o - r);
282  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
283  {
284  max_err = err > max_err ? err : max_err;
285  err_count++;
286  if(err_count < ERROR_DETAIL_LIMIT)
287  {
288  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
289  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
290  }
291  res = false;
292  }
293  }
294  if(!res)
295  {
296  report_error_stats(err_count, max_err, ref.size());
297  }
298  return res;
299 }
300 
317 template <typename Range, typename RefRange>
318 typename std::enable_if<
319  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
320  std::is_same_v<ranges::range_value_t<Range>, bf16_t>,
321  bool>::type CK_TILE_HOST
322 check_err(const Range& out,
323  const RefRange& ref,
324  const std::string& msg = "Error: Incorrect results!",
325  double rtol = 1e-3,
326  double atol = 1e-3,
327  bool allow_infinity_ref = false)
328 {
329  if(check_size_mismatch(out, ref, msg))
330  return false;
331 
332  const auto is_infinity_error = [=](auto o, auto r) {
333  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
334  const bool both_infinite_and_same =
335  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
336 
337  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
338  };
339 
340  bool res{true};
341  int err_count = 0;
342  double err = 0;
343  // TODO: This is a hack. We should have proper specialization for bf16_t data type.
344  double max_err = std::numeric_limits<float>::min();
345  for(std::size_t i = 0; i < ref.size(); ++i)
346  {
347  const double o = type_convert<float>(*std::next(std::begin(out), i));
348  const double r = type_convert<float>(*std::next(std::begin(ref), i));
349  err = std::abs(o - r);
350  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
351  {
352  max_err = err > max_err ? err : max_err;
353  err_count++;
354  if(err_count < ERROR_DETAIL_LIMIT)
355  {
356  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
357  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
358  }
359  res = false;
360  }
361  }
362  if(!res)
363  {
364  report_error_stats(err_count, max_err, ref.size());
365  }
366  return res;
367 }
368 
386 template <typename Range, typename RefRange>
387 typename std::enable_if<
388  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
389  std::is_same_v<ranges::range_value_t<Range>, half_t>,
390  bool>::type CK_TILE_HOST
391 check_err(const Range& out,
392  const RefRange& ref,
393  const std::string& msg = "Error: Incorrect results!",
394  double rtol = 1e-3,
395  double atol = 1e-3,
396  bool allow_infinity_ref = false)
397 {
398  if(check_size_mismatch(out, ref, msg))
399  return false;
400 
401  const auto is_infinity_error = [=](auto o, auto r) {
402  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
403  const bool both_infinite_and_same =
404  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
405 
406  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
407  };
408 
409  bool res{true};
410  int err_count = 0;
411  double err = 0;
412  double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());
413  for(std::size_t i = 0; i < ref.size(); ++i)
414  {
415  const double o = type_convert<float>(*std::next(std::begin(out), i));
416  const double r = type_convert<float>(*std::next(std::begin(ref), i));
417  err = std::abs(o - r);
418  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
419  {
420  max_err = err > max_err ? err : max_err;
421  err_count++;
422  if(err_count < ERROR_DETAIL_LIMIT)
423  {
424  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
425  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
426  }
427  res = false;
428  }
429  }
430  if(!res)
431  {
432  report_error_stats(err_count, max_err, ref.size());
433  }
434  return res;
435 }
436 
452 template <typename Range, typename RefRange>
453 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
454  std::is_integral_v<ranges::range_value_t<Range>> &&
455  !std::is_same_v<ranges::range_value_t<Range>, bf16_t>)
456 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
457  || std::is_same_v<ranges::range_value_t<Range>, int4_t>
458 #endif
459  ,
460  bool>
461  CK_TILE_HOST check_err(const Range& out,
462  const RefRange& ref,
463  const std::string& msg = "Error: Incorrect results!",
464  double = 0,
465  double atol = 0)
466 {
467  if(check_size_mismatch(out, ref, msg))
468  return false;
469 
470  bool res{true};
471  int err_count = 0;
472  int64_t err = 0;
474  for(std::size_t i = 0; i < ref.size(); ++i)
475  {
476  const int64_t o = *std::next(std::begin(out), i);
477  const int64_t r = *std::next(std::begin(ref), i);
478  err = std::abs(o - r);
479 
480  if(err > atol)
481  {
482  max_err = err > max_err ? err : max_err;
483  err_count++;
484  if(err_count < ERROR_DETAIL_LIMIT)
485  {
486  std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
487  << std::endl;
488  }
489  res = false;
490  }
491  }
492  if(!res)
493  {
494  report_error_stats(err_count, static_cast<double>(max_err), ref.size());
495  }
496  return res;
497 }
498 
516 template <typename Range, typename RefRange>
517 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
518  std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
519  bool>
520  CK_TILE_HOST check_err(const Range& out,
521  const RefRange& ref,
522  const std::string& msg = "Error: Incorrect results!",
523  unsigned max_rounding_point_distance = 1,
524  double atol = 1e-1,
525  bool allow_infinity_ref = false)
526 {
527  if(check_size_mismatch(out, ref, msg))
528  return false;
529 
530  const auto is_infinity_error = [=](auto o, auto r) {
531  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
532  const bool both_infinite_and_same =
533  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
534 
535  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
536  };
537 
538  static const auto get_rounding_point_distance = [](fp8_t o, fp8_t r) -> unsigned {
539  static const auto get_sign_bit = [](fp8_t v) -> bool {
540  return 0x80 & bit_cast<uint8_t>(v);
541  };
542 
543  if(get_sign_bit(o) ^ get_sign_bit(r))
544  {
546  }
547  else
548  {
549  return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
550  }
551  };
552 
553  bool res{true};
554  int err_count = 0;
555  double err = 0;
556  double max_err = std::numeric_limits<float>::min();
557  for(std::size_t i = 0; i < ref.size(); ++i)
558  {
559  const fp8_t o_fp8 = *std::next(std::begin(out), i);
560  const fp8_t r_fp8 = *std::next(std::begin(ref), i);
561  const double o_fp64 = type_convert<float>(o_fp8);
562  const double r_fp64 = type_convert<float>(r_fp8);
563  err = std::abs(o_fp64 - r_fp64);
564  if(!(less_equal<double>{}(err, atol) ||
565  get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
566  is_infinity_error(o_fp64, r_fp64))
567  {
568  max_err = err > max_err ? err : max_err;
569  err_count++;
570  if(err_count < ERROR_DETAIL_LIMIT)
571  {
572  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
573  << "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
574  }
575  res = false;
576  }
577  }
578  if(!res)
579  {
580  report_error_stats(err_count, max_err, ref.size());
581  }
582  return res;
583 }
584 
601 template <typename Range, typename RefRange>
602 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
603  std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
604  bool>
605  CK_TILE_HOST check_err(const Range& out,
606  const RefRange& ref,
607  const std::string& msg = "Error: Incorrect results!",
608  double rtol = 1e-3,
609  double atol = 1e-3,
610  bool allow_infinity_ref = false)
611 {
612  if(check_size_mismatch(out, ref, msg))
613  return false;
614 
615  const auto is_infinity_error = [=](auto o, auto r) {
616  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
617  const bool both_infinite_and_same =
618  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
619 
620  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
621  };
622 
623  bool res{true};
624  int err_count = 0;
625  double err = 0;
626  double max_err = std::numeric_limits<float>::min();
627  for(std::size_t i = 0; i < ref.size(); ++i)
628  {
629  const double o = type_convert<float>(*std::next(std::begin(out), i));
630  const double r = type_convert<float>(*std::next(std::begin(ref), i));
631  err = std::abs(o - r);
632  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
633  {
634  max_err = err > max_err ? err : max_err;
635  err_count++;
636  if(err_count < ERROR_DETAIL_LIMIT)
637  {
638  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
639  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
640  }
641  res = false;
642  }
643  }
644  if(!res)
645  {
646  report_error_stats(err_count, max_err, ref.size());
647  }
648  return res;
649 }
650 
651 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:33
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:31
_BitInt(8) fp8_t
Definition: float8.hpp:204
CK_TILE_HOST bool check_size_mismatch(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!")
Check for size mismatch between output and reference ranges.
Definition: check_err.hpp:201
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations=1)
Calculate relative error threshold for numerical comparisons.
Definition: check_err.hpp:52
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
Report error statistics for numerical comparisons.
Definition: check_err.hpp:223
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Calculate absolute error threshold for numerical comparisons.
Definition: check_err.hpp:112
std::ostream & operator<<(std::ostream &os, const std::vector< T > &v)
Stream operator overload for vector output.
Definition: check_err.hpp:172
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type CK_TILE_HOST check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6, bool allow_infinity_ref=false)
Check errors between floating point ranges using the specified tolerances.
Definition: check_err.hpp:254
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:29
int32_t int32_t
Definition: integer.hpp:10
ck_tile::bf8_t BF8
8-bit brain floating point type
Definition: check_err.hpp:27
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:37
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:25
constexpr int ERROR_DETAIL_LIMIT
Maximum number of error values to display when checking errors.
Definition: check_err.hpp:22
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
_Float16 half_t
Definition: half.hpp:111
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:35
_BitInt(4) int4_t
Definition: data_type.hpp:31
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
constexpr bool is_same_v
Definition: type.hpp:283
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
signed __int64 int64_t
Definition: stdint.h:135
Definition: type_traits.hpp:115
Definition: math.hpp:395
Definition: numeric.hpp:81