/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/check_err.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/check_err.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/check_err.hpp Source File
check_err.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <algorithm>
7 #include <cmath>
8 #include <cstdlib>
9 #include <iostream>
10 #include <iomanip>
11 #include <iterator>
12 #include <limits>
13 #include <type_traits>
14 #include <vector>
15 
16 #include "ck_tile/core.hpp"
17 #include "ck_tile/host/ranges.hpp"
18 
19 namespace ck_tile {
20 
// Maximum number of error values to display when checking errors.
22 constexpr int ERROR_DETAIL_LIMIT = 128;
23 
// NOTE(review): listing lines 24-32 are not present in this extract; the F8 / BF8 / F16 / BF16
// aliases used below are presumably declared there - confirm against the original header.
// 32-bit floating point (single precision) type.
33 using F32 = float;
// 8-bit signed integer type.
35 using I8 = int8_t;
// 32-bit signed integer type.
37 using I32 = int32_t;
38 
// Calculate relative error threshold for numerical comparisons.
//
// Template parameters: ComputeDataType, OutDataType and AccDataType (defaults to
// ComputeDataType). For each of them an error term is derived from
// numeric_traits<T>::mant as 2^(-mant) * 0.5; the accumulator term is additionally
// multiplied by number_of_accumulations. The result is the largest of the three terms.
// (mant is presumably the mantissa bit count of the type - confirm in numeric_traits.)
//
// number_of_accumulations: scale factor applied to the accumulator error term (default 1).
// Returns: max(acc_error, max(compute_error, output_error)), or 0 from one of the early
// returns below.
//
// NOTE(review): the listing skips numbers 62, 70, 79, 83, 93 and 97. The conditionals
// guarding each `return 0;` branch and the heads of the OutDataType / AccDataType
// static_asserts are therefore not visible here - confirm against the original header.
51 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
52 CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1)
53 {
54 
// Compile-time guard: only the listed compute types are accepted.
55  static_assert(is_any_of<ComputeDataType,
56  F8,
57  BF8,
58  F16,
59  BF16,
60  F32,
61  pk_fp4_t,
63  pk_int4_t,
64  I8,
65  I32,
66  int>::value,
67  "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
68 
69  double compute_error = 0;
// (condition missing from this extract) early-out branch: threshold is 0.
71  {
72  return 0;
73  }
74  else
75  {
// Error term of the compute type: 2^(-mant) * 0.5.
76  compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
77  }
78 
// Tail of the OutDataType static_assert (its head is missing from this extract).
80  "Warning: Unhandled OutDataType for setting up the relative threshold!");
81 
82  double output_error = 0;
// (condition missing from this extract) early-out branch: threshold is 0.
84  {
85  return 0;
86  }
87  else
88  {
// Error term of the output type: 2^(-mant) * 0.5.
89  output_error = std::pow(2, -numeric_traits<OutDataType>::mant) * 0.5;
90  }
// The larger of the compute and output error terms.
91  double midway_error = std::max(compute_error, output_error);
92 
// Tail of the AccDataType static_assert (its head is missing from this extract).
94  "Warning: Unhandled AccDataType for setting up the relative threshold!");
95 
96  double acc_error = 0;
// (condition missing from this extract) early-out branch: threshold is 0.
98  {
99  return 0;
100  }
101  else
102  {
// Accumulator error term, scaled linearly by the number of accumulations.
103  acc_error = std::pow(2, -numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
104  }
105  return std::max(acc_error, midway_error);
106 }
107 
// Calculate absolute error threshold for numerical comparisons.
//
// Same structure as the relative threshold, but every error term is scaled by the
// magnitude of max_possible_num: each term is 2^(expo - mant) * 0.5 with
// expo = log2(|max_possible_num|) and mant = numeric_traits<T>::mant. The accumulator
// term is additionally multiplied by number_of_accumulations.
//
// max_possible_num: magnitude used to derive the exponent `expo`.
// number_of_accumulations: scale factor applied to the accumulator error term (default 1).
// Returns: max(acc_error, max(compute_error, output_error)), or 0 from one of the early
// returns below.
//
// NOTE(review): the listing skips numbers 142, 151, 155, 165 and 169. The conditionals
// guarding each `return 0;` branch and the heads of the OutDataType / AccDataType
// static_asserts are therefore not visible here - confirm against the original header.
121 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
122 CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
123  const int number_of_accumulations = 1)
124 {
125 
// Compile-time guard: only the listed compute types are accepted.
126  static_assert(is_any_of<ComputeDataType,
127  F8,
128  BF8,
129  F16,
130  BF16,
131  F32,
132  pk_fp4_t,
133  pk_fp4_raw_t,
134  pk_int4_t,
135  I8,
136  I32,
137  int>::value,
138  "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
139 
// Binary exponent of the largest expected magnitude; shared by all three error terms.
140  auto expo = std::log2(std::abs(max_possible_num));
141  double compute_error = 0;
// (condition missing from this extract) early-out branch: threshold is 0.
143  {
144  return 0;
145  }
146  else
147  {
// Error term of the compute type: 2^(expo - mant) * 0.5.
148  compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
149  }
150 
// Tail of the OutDataType static_assert (its head is missing from this extract).
152  "Warning: Unhandled OutDataType for setting up the absolute threshold!");
153 
154  double output_error = 0;
// (condition missing from this extract) early-out branch: threshold is 0.
156  {
157  return 0;
158  }
159  else
160  {
// Error term of the output type: 2^(expo - mant) * 0.5.
161  output_error = std::pow(2, expo - numeric_traits<OutDataType>::mant) * 0.5;
162  }
// The larger of the compute and output error terms.
163  double midway_error = std::max(compute_error, output_error);
164 
// Tail of the AccDataType static_assert (its head is missing from this extract).
166  "Warning: Unhandled AccDataType for setting up the absolute threshold!");
167 
168  double acc_error = 0;
// (condition missing from this extract) early-out branch: threshold is 0.
170  {
171  return 0;
172  }
173  else
174  {
// Accumulator error term, scaled linearly by the number of accumulations.
175  acc_error =
176  std::pow(2, expo - numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
177  }
178  return std::max(acc_error, midway_error);
179 }
180 
/// @brief Stream operator overload for vector output.
///
/// Writes the elements of @p v as a bracketed, comma-separated list, e.g. `[1, 2, 3]`.
/// An empty vector is written as `[]`.
///
/// @tparam T  Element type; must itself be streamable with `operator<<`.
/// @param os  Destination stream.
/// @param v   Vector whose elements are written in order.
/// @return    Reference to @p os, to allow chaining.
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    bool is_first = true;
    for(const auto& element : v)
    {
        // The separator goes between elements only, never before the first one.
        if(!is_first)
        {
            os << ", ";
        }
        os << element;
        is_first = false;
    }
    return os << "]";
}
207 
220 template <typename Range, typename RefRange>
221 CK_TILE_HOST bool check_size_mismatch(const Range& out,
222  const RefRange& ref,
223  const std::string& msg = "Error: Incorrect results!")
224 {
225  if(out.size() != ref.size())
226  {
227  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
228  << std::endl;
229  return true;
230  }
231  return false;
232 }
233 
243 CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
244 {
245  const float error_percent =
246  static_cast<float>(err_count) / static_cast<float>(total_size) * 100.f;
247  std::cerr << "max err: " << max_err;
248  std::cerr << ", number of errors: " << err_count;
249  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
250 }
251 
268 template <typename Range, typename RefRange>
269 typename std::enable_if<
270  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
271  std::is_floating_point_v<ranges::range_value_t<Range>> &&
272  !std::is_same_v<ranges::range_value_t<Range>, half_t>,
273  bool>::type CK_TILE_HOST
274 check_err(const Range& out,
275  const RefRange& ref,
276  const std::string& msg = "Error: Incorrect results!",
277  double rtol = 1e-5,
278  double atol = 3e-6,
279  bool allow_infinity_ref = false)
280 {
281 
282  if(check_size_mismatch(out, ref, msg))
283  return false;
284 
285  const auto is_infinity_error = [=](auto o, auto r) {
286  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
287  const bool both_infinite_and_same =
288  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
289 
290  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
291  };
292 
293  bool res{true};
294  int err_count = 0;
295  double err = 0;
296  double max_err = std::numeric_limits<double>::min();
297  for(std::size_t i = 0; i < ref.size(); ++i)
298  {
299  const double o = *std::next(std::begin(out), i);
300  const double r = *std::next(std::begin(ref), i);
301  err = std::abs(o - r);
302  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
303  {
304  max_err = err > max_err ? err : max_err;
305  err_count++;
306  if(err_count < ERROR_DETAIL_LIMIT)
307  {
308  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
309  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
310  }
311  res = false;
312  }
313  }
314  if(!res)
315  {
316  report_error_stats(err_count, max_err, ref.size());
317  }
318  return res;
319 }
320 
337 template <typename Range, typename RefRange>
338 typename std::enable_if<
339  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
340  std::is_same_v<ranges::range_value_t<Range>, bf16_t>,
341  bool>::type CK_TILE_HOST
342 check_err(const Range& out,
343  const RefRange& ref,
344  const std::string& msg = "Error: Incorrect results!",
345  double rtol = 1e-3,
346  double atol = 1e-3,
347  bool allow_infinity_ref = false)
348 {
349  if(check_size_mismatch(out, ref, msg))
350  return false;
351 
352  const auto is_infinity_error = [=](auto o, auto r) {
353  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
354  const bool both_infinite_and_same =
355  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
356 
357  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
358  };
359 
360  bool res{true};
361  int err_count = 0;
362  double err = 0;
363  // TODO: This is a hack. We should have proper specialization for bf16_t data type.
364  double max_err = std::numeric_limits<float>::min();
365  for(std::size_t i = 0; i < ref.size(); ++i)
366  {
367  const double o = type_convert<float>(*std::next(std::begin(out), i));
368  const double r = type_convert<float>(*std::next(std::begin(ref), i));
369  err = std::abs(o - r);
370  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
371  {
372  max_err = err > max_err ? err : max_err;
373  err_count++;
374  if(err_count < ERROR_DETAIL_LIMIT)
375  {
376  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
377  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
378  }
379  res = false;
380  }
381  }
382  if(!res)
383  {
384  report_error_stats(err_count, max_err, ref.size());
385  }
386  return res;
387 }
388 
406 template <typename Range, typename RefRange>
407 typename std::enable_if<
408  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
409  std::is_same_v<ranges::range_value_t<Range>, half_t>,
410  bool>::type CK_TILE_HOST
411 check_err(const Range& out,
412  const RefRange& ref,
413  const std::string& msg = "Error: Incorrect results!",
414  double rtol = 1e-3,
415  double atol = 1e-3,
416  bool allow_infinity_ref = false)
417 {
418  if(check_size_mismatch(out, ref, msg))
419  return false;
420 
421  const auto is_infinity_error = [=](auto o, auto r) {
422  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
423  const bool both_infinite_and_same =
424  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
425 
426  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
427  };
428 
429  bool res{true};
430  int err_count = 0;
431  double err = 0;
432  double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());
433  for(std::size_t i = 0; i < ref.size(); ++i)
434  {
435  const double o = type_convert<float>(*std::next(std::begin(out), i));
436  const double r = type_convert<float>(*std::next(std::begin(ref), i));
437  err = std::abs(o - r);
438  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
439  {
440  max_err = err > max_err ? err : max_err;
441  err_count++;
442  if(err_count < ERROR_DETAIL_LIMIT)
443  {
444  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
445  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
446  }
447  res = false;
448  }
449  }
450  if(!res)
451  {
452  report_error_stats(err_count, max_err, ref.size());
453  }
454  return res;
455 }
456 
// Check errors between integral ranges.
//
// Selected when both ranges share the same value type and that type is integral and not
// bf16_t (or is int4_t when CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 is defined).
// Elements are widened to int64_t; an element is an error when |out - ref| > atol.
// The unnamed fourth parameter is accepted and ignored (presumably to keep the signature
// parallel to the floating-point overloads, where it is rtol - confirm).
//
// out / ref : ranges to compare; msg : prefix for every diagnostic line;
// atol      : absolute tolerance (default 0, i.e. exact match).
// Returns true when the sizes match and no element is an error, false otherwise.
//
// NOTE(review): the listing skips number 493, so the declaration of `max_err` used below
// is not visible in this extract - confirm against the original header.
472 template <typename Range, typename RefRange>
473 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
474  std::is_integral_v<ranges::range_value_t<Range>> &&
475  !std::is_same_v<ranges::range_value_t<Range>, bf16_t>)
476 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
477  || std::is_same_v<ranges::range_value_t<Range>, int4_t>
478 #endif
479  ,
480  bool>
481  CK_TILE_HOST check_err(const Range& out,
482  const RefRange& ref,
483  const std::string& msg = "Error: Incorrect results!",
484  double = 0,
485  double atol = 0)
486 {
// Differently sized ranges can never match; the helper already printed the diagnostic.
487  if(check_size_mismatch(out, ref, msg))
488  return false;
489 
490  bool res{true};
491  int err_count = 0;
492  int64_t err = 0;
494  for(std::size_t i = 0; i < ref.size(); ++i)
495  {
// Widen both values to int64_t before subtracting.
496  const int64_t o = *std::next(std::begin(out), i);
497  const int64_t r = *std::next(std::begin(ref), i);
498  err = std::abs(o - r);
499 
500  if(err > atol)
501  {
502  max_err = err > max_err ? err : max_err;
503  err_count++;
// Per-element details are printed only while the error count is below ERROR_DETAIL_LIMIT.
504  if(err_count < ERROR_DETAIL_LIMIT)
505  {
506  std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
507  << std::endl;
508  }
509  res = false;
510  }
511  }
// Summary line is emitted only when at least one element failed.
512  if(!res)
513  {
514  report_error_stats(err_count, static_cast<double>(max_err), ref.size());
515  }
516  return res;
517 }
518 
// Check errors between fp8_t ranges.
//
// Selected when both ranges have value type fp8_t. An element passes when either the
// absolute difference of the float-converted values is within atol, or the distance
// returned by get_rounding_point_distance is at most max_rounding_point_distance; it
// fails regardless when is_infinity_error reports a non-finite value.
//
// out / ref                   : ranges to compare; msg : prefix for every diagnostic line.
// max_rounding_point_distance : largest accepted difference between the two bit patterns.
// atol                        : absolute tolerance on the converted values.
// allow_infinity_ref          : accept matching infinities in out and ref.
// Returns true when the sizes match and no element is an error, false otherwise.
//
// NOTE(review): the listing skips number 565, so the body of the sign-mismatch branch in
// get_rounding_point_distance is not visible in this extract - confirm against the header.
536 template <typename Range, typename RefRange>
537 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
538  std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
539  bool>
540  CK_TILE_HOST check_err(const Range& out,
541  const RefRange& ref,
542  const std::string& msg = "Error: Incorrect results!",
543  unsigned max_rounding_point_distance = 1,
544  double atol = 1e-1,
545  bool allow_infinity_ref = false)
546 {
// Differently sized ranges can never match; the helper already printed the diagnostic.
547  if(check_size_mismatch(out, ref, msg))
548  return false;
549 
// Non-finite values are errors unless allow_infinity_ref is set and both values are the
// same infinity (identical bit pattern).
550  const auto is_infinity_error = [=](auto o, auto r) {
551  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
552  const bool both_infinite_and_same =
553  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
554 
555  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
556  };
557 
// Distance between two fp8 values measured on their raw bit patterns.
558  static const auto get_rounding_point_distance = [](fp8_t o, fp8_t r) -> unsigned {
// Tests bit 0x80 of the raw byte.
559  static const auto get_sign_bit = [](fp8_t v) -> bool {
560  return 0x80 & bit_cast<uint8_t>(v);
561  };
562 
// Sign bits differ: handled by the branch whose body is missing from this extract.
563  if(get_sign_bit(o) ^ get_sign_bit(r))
564  {
566  }
567  else
568  {
// Same sign: absolute difference of the bit patterns reinterpreted as int8_t.
569  return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
570  }
571  };
572 
573  bool res{true};
574  int err_count = 0;
575  double err = 0;
576  double max_err = std::numeric_limits<float>::min();
577  for(std::size_t i = 0; i < ref.size(); ++i)
578  {
579  const fp8_t o_fp8 = *std::next(std::begin(out), i);
580  const fp8_t r_fp8 = *std::next(std::begin(ref), i);
// Converted via float and stored as double for the tolerance check and the printout.
581  const double o_fp64 = type_convert<float>(o_fp8);
582  const double r_fp64 = type_convert<float>(r_fp8);
583  err = std::abs(o_fp64 - r_fp64);
// Error when neither acceptance criterion holds, or when a non-finite value is involved.
584  if(!(less_equal<double>{}(err, atol) ||
585  get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
586  is_infinity_error(o_fp64, r_fp64))
587  {
588  max_err = err > max_err ? err : max_err;
589  err_count++;
// Per-element details are printed only while the error count is below ERROR_DETAIL_LIMIT.
590  if(err_count < ERROR_DETAIL_LIMIT)
591  {
592  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
593  << "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
594  }
595  res = false;
596  }
597  }
// Summary line is emitted only when at least one element failed.
598  if(!res)
599  {
600  report_error_stats(err_count, max_err, ref.size());
601  }
602  return res;
603 }
604 
621 template <typename Range, typename RefRange>
622 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
623  std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
624  bool>
625  CK_TILE_HOST check_err(const Range& out,
626  const RefRange& ref,
627  const std::string& msg = "Error: Incorrect results!",
628  double rtol = 1e-3,
629  double atol = 1e-3,
630  bool allow_infinity_ref = false)
631 {
632  if(check_size_mismatch(out, ref, msg))
633  return false;
634 
635  const auto is_infinity_error = [=](auto o, auto r) {
636  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
637  const bool both_infinite_and_same =
638  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
639 
640  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
641  };
642 
643  bool res{true};
644  int err_count = 0;
645  double err = 0;
646  double max_err = std::numeric_limits<float>::min();
647  for(std::size_t i = 0; i < ref.size(); ++i)
648  {
649  const double o = type_convert<float>(*std::next(std::begin(out), i));
650  const double r = type_convert<float>(*std::next(std::begin(ref), i));
651  err = std::abs(o - r);
652  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
653  {
654  max_err = err > max_err ? err : max_err;
655  err_count++;
656  if(err_count < ERROR_DETAIL_LIMIT)
657  {
658  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
659  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
660  }
661  res = false;
662  }
663  }
664  if(!res)
665  {
666  report_error_stats(err_count, max_err, ref.size());
667  }
668  return res;
669 }
670 
684 template <typename Range, typename RefRange>
685 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
686  std::is_same_v<ranges::range_value_t<Range>, pk_fp4_t>),
687  bool>
688  CK_TILE_HOST check_err(const Range& out,
689  const RefRange& ref,
690  const std::string& msg = "Error: Incorrect results!",
691  double = 0,
692  double = 0)
693 {
694  if(check_size_mismatch(out, ref, msg))
695  return false;
696 
697  int err_count = 0;
698 
699  auto update_err = [&](pk_fp4_raw_t o, pk_fp4_raw_t r, std::size_t index) {
700  if(o != r)
701  {
702  std::cerr << msg << " out[" << index << "] != ref[" << index
703  << "]: " << type_convert<float>(pk_fp4_t{o})
704  << " != " << type_convert<float>(pk_fp4_t{r}) << std::endl;
705  ++err_count;
706  }
707  };
708 
709  for(std::size_t i = 0; i < ref.size(); ++i)
710  {
711  const pk_fp4_t o = *std::next(std::begin(out), i);
712  const pk_fp4_t r = *std::next(std::begin(ref), i);
713  update_err(o._unpack(number<0>{}), r._unpack(number<0>{}), i * 2);
714  update_err(o._unpack(number<1>{}), r._unpack(number<1>{}), i * 2 + 1);
715  }
716  if(err_count > 0)
717  {
718  report_error_stats(err_count, numeric<pk_fp4_t>::max(), ref.size());
719  }
720  return err_count == 0;
721 }
722 
723 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:44
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:33
typename pk_fp4_t::type pk_fp4_raw_t
Definition: pk_fp4.hpp:152
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:31
_BitInt(8) fp8_t
Definition: float8.hpp:204
CK_TILE_HOST bool check_size_mismatch(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!")
Check for size mismatch between output and reference ranges.
Definition: check_err.hpp:221
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations=1)
Calculate relative error threshold for numerical comparisons.
Definition: check_err.hpp:52
pk_float4_e2m1_t pk_fp4_t
Definition: pk_fp4.hpp:151
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
Report error statistics for numerical comparisons.
Definition: check_err.hpp:243
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Calculate absolute error threshold for numerical comparisons.
Definition: check_err.hpp:122
std::ostream & operator<<(std::ostream &os, const std::vector< T > &v)
Stream operator overload for vector output.
Definition: check_err.hpp:192
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type CK_TILE_HOST check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6, bool allow_infinity_ref=false)
Check errors between floating point ranges using the specified tolerances.
Definition: check_err.hpp:274
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:29
int32_t int32_t
Definition: integer.hpp:10
ck_tile::bf8_t BF8
8-bit brain floating point type
Definition: check_err.hpp:27
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:37
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:25
constexpr int ERROR_DETAIL_LIMIT
Maximum number of error values to display when checking errors.
Definition: check_err.hpp:22
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:206
_Float16 half_t
Definition: half.hpp:111
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:35
_BitInt(4) int4_t
Definition: data_type.hpp:32
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
constexpr bool is_same_v
Definition: type.hpp:283
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1697
signed __int64 int64_t
Definition: stdint.h:135
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:115
Definition: math.hpp:389
Definition: numeric.hpp:81
Definition: numeric.hpp:18
Definition: pk_fp4.hpp:76
constexpr CK_TILE_HOST_DEVICE type _unpack(number< I >) const
Definition: pk_int4.hpp:21