include/ck_tile/host/check_err.hpp Source File

include/ck_tile/host/check_err.hpp Source File

Composable Kernel: include/ck_tile/host/check_err.hpp Source File
check_err.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <algorithm>
7 #include <cmath>
8 #include <cstdlib>
9 #include <iostream>
10 #include <iomanip>
11 #include <iterator>
12 #include <limits>
13 #include <type_traits>
14 #include <vector>
15 
16 #include "ck_tile/core.hpp"
17 #include "ck_tile/host/ranges.hpp"
18 
19 namespace ck_tile {
20 
// Relative tolerance for validating a result produced with the given data types.
// For ComputeDataType and OutDataType the per-type error term is
// 0.5 * 2^-numeric_traits<T>::mant, and midway_error is the larger of the two. The
// AccDataType term uses the same formula scaled by number_of_accumulations; the
// function returns the largest of all three terms.
// Each of the three types is also screened by a check whose failure message is
// "Warning: Unhandled ... for setting up the relative threshold!" and by a branch that
// makes the whole function return 0 outright.
// NOTE(review): the lines that open those checks (source lines 32, 36, 45, 49, 59, 63,
// presumably `static_assert(...` and `if constexpr(...)`) are missing from this
// listing; presumably the early `return 0` is taken for integer types such as I8/I32.
// Confirm against the actual header.
// NOTE(review): assuming numeric_traits<T>::mant is the number of mantissa bits, each
// term is half an ULP relative to 1.0.
21 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
22 double get_relative_threshold(const int number_of_accumulations = 1)
23 {
24  using F8 = ck_tile::fp8_t;
25  using BF8 = ck_tile::bf8_t;
26  using F16 = ck_tile::half_t;
27  using BF16 = ck_tile::bf16_t;
28  using F32 = float;
29  using I8 = int8_t;
30  using I32 = int32_t;
31 
33  "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
34 
35  double compute_error = 0;
37  {
38  return 0;
39  }
40  else
41  {
42  compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
43  }
44 
46  "Warning: Unhandled OutDataType for setting up the relative threshold!");
47 
48  double output_error = 0;
50  {
51  return 0;
52  }
53  else
54  {
55  output_error = std::pow(2, -numeric_traits<OutDataType>::mant) * 0.5;
56  }
// The coarser of the compute and output precisions bounds the per-element rounding.
57  double midway_error = std::max(compute_error, output_error);
58 
60  "Warning: Unhandled AccDataType for setting up the relative threshold!");
61 
62  double acc_error = 0;
64  {
65  return 0;
66  }
67  else
68  {
// Accumulation error grows linearly with the number of accumulated terms.
69  acc_error = std::pow(2, -numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
70  }
71  return std::max(acc_error, midway_error);
72 }
73 
// Absolute tolerance for validating a result whose magnitude is at most max_possible_num.
// expo = log2(|max_possible_num|) is used unrounded, so each per-type term
// 0.5 * 2^(expo - numeric_traits<T>::mant) equals |max_possible_num| times the
// corresponding relative term of get_relative_threshold. The AccDataType term is further
// scaled by number_of_accumulations. The function returns the largest of the three terms.
// If max_possible_num == 0, expo is -inf and every term evaluates to 0.
// NOTE(review): the lines opening each type check (source lines 85, 90, 99, 103, 113,
// 117, presumably `static_assert(...` and `if constexpr(...)`) are missing from this
// listing; presumably the early `return 0` is taken for integer types such as I8/I32.
// Confirm against the actual header.
74 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
75 double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
76 {
77  using F8 = ck_tile::fp8_t;
78  using BF8 = ck_tile::bf8_t;
79  using F16 = ck_tile::half_t;
80  using BF16 = ck_tile::bf16_t;
81  using F32 = float;
82  using I8 = int8_t;
83  using I32 = int32_t;
84 
86  "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
87 
// Binary exponent of the largest expected magnitude (not rounded to an integer).
88  auto expo = std::log2(std::abs(max_possible_num));
89  double compute_error = 0;
91  {
92  return 0;
93  }
94  else
95  {
96  compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
97  }
98 
100  "Warning: Unhandled OutDataType for setting up the absolute threshold!");
101 
102  double output_error = 0;
104  {
105  return 0;
106  }
107  else
108  {
109  output_error = std::pow(2, expo - numeric_traits<OutDataType>::mant) * 0.5;
110  }
111  double midway_error = std::max(compute_error, output_error);
112 
114  "Warning: Unhandled AccDataType for setting up the absolute threshold!");
115 
116  double acc_error = 0;
118  {
119  return 0;
120  }
121  else
122  {
123  acc_error =
124  std::pow(2, expo - numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
125  }
126  return std::max(acc_error, midway_error);
127 }
128 
/// @brief Stream a std::vector as "[e0, e1, ..., eN]".
///
/// Each element is written with its own operator<<, separated by ", ".
/// An empty vector prints as "[]".
/// @param os Destination stream.
/// @param v  Vector whose elements are printed in order.
/// @return The same stream, to allow chaining.
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    // The separator is empty before the first element and ", " afterwards.
    const char* separator = "";
    for(const auto& element : v)
    {
        os << separator << element;
        separator = ", ";
    }
    return os << "]";
}
145 
146 template <typename Range, typename RefRange>
147 typename std::enable_if<
148  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
149  std::is_floating_point_v<ranges::range_value_t<Range>> &&
150  !std::is_same_v<ranges::range_value_t<Range>, half_t>,
151  bool>::type CK_TILE_HOST
152 check_err(const Range& out,
153  const RefRange& ref,
154  const std::string& msg = "Error: Incorrect results!",
155  double rtol = 1e-5,
156  double atol = 3e-6,
157  bool allow_infinity_ref = false)
158 {
159  if(out.size() != ref.size())
160  {
161  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
162  << std::endl;
163  return false;
164  }
165 
166  const auto is_infinity_error = [=](auto o, auto r) {
167  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
168  const bool both_infinite_and_same =
169  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
170 
171  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
172  };
173 
174  bool res{true};
175  int err_count = 0;
176  double err = 0;
177  double max_err = std::numeric_limits<double>::min();
178  for(std::size_t i = 0; i < ref.size(); ++i)
179  {
180  const double o = *std::next(std::begin(out), i);
181  const double r = *std::next(std::begin(ref), i);
182  err = std::abs(o - r);
183  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
184  {
185  max_err = err > max_err ? err : max_err;
186  err_count++;
187  if(err_count < 5)
188  {
189  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
190  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
191  }
192  res = false;
193  }
194  }
195  if(!res)
196  {
197  const float error_percent =
198  static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
199  std::cerr << "max err: " << max_err;
200  std::cerr << ", number of errors: " << err_count;
201  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
202  }
203  return res;
204 }
205 
206 template <typename Range, typename RefRange>
207 typename std::enable_if<
208  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
209  std::is_same_v<ranges::range_value_t<Range>, bf16_t>,
210  bool>::type CK_TILE_HOST
211 check_err(const Range& out,
212  const RefRange& ref,
213  const std::string& msg = "Error: Incorrect results!",
214  double rtol = 1e-3,
215  double atol = 1e-3,
216  bool allow_infinity_ref = false)
217 {
218  if(out.size() != ref.size())
219  {
220  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
221  << std::endl;
222  return false;
223  }
224 
225  const auto is_infinity_error = [=](auto o, auto r) {
226  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
227  const bool both_infinite_and_same =
228  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
229 
230  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
231  };
232 
233  bool res{true};
234  int err_count = 0;
235  double err = 0;
236  // TODO: This is a hack. We should have proper specialization for bf16_t data type.
237  double max_err = std::numeric_limits<float>::min();
238  for(std::size_t i = 0; i < ref.size(); ++i)
239  {
240  const double o = type_convert<float>(*std::next(std::begin(out), i));
241  const double r = type_convert<float>(*std::next(std::begin(ref), i));
242  err = std::abs(o - r);
243  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
244  {
245  max_err = err > max_err ? err : max_err;
246  err_count++;
247  if(err_count < 5)
248  {
249  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
250  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
251  }
252  res = false;
253  }
254  }
255  if(!res)
256  {
257  const float error_percent =
258  static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
259  std::cerr << "max err: " << max_err;
260  std::cerr << ", number of errors: " << err_count;
261  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
262  }
263  return res;
264 }
265 
266 template <typename Range, typename RefRange>
267 typename std::enable_if<
268  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
269  std::is_same_v<ranges::range_value_t<Range>, half_t>,
270  bool>::type CK_TILE_HOST
271 check_err(const Range& out,
272  const RefRange& ref,
273  const std::string& msg = "Error: Incorrect results!",
274  double rtol = 1e-3,
275  double atol = 1e-3,
276  bool allow_infinity_ref = false)
277 {
278  if(out.size() != ref.size())
279  {
280  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
281  << std::endl;
282  return false;
283  }
284 
285  const auto is_infinity_error = [=](auto o, auto r) {
286  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
287  const bool both_infinite_and_same =
288  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
289 
290  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
291  };
292 
293  bool res{true};
294  int err_count = 0;
295  double err = 0;
296  double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());
297  for(std::size_t i = 0; i < ref.size(); ++i)
298  {
299  const double o = type_convert<float>(*std::next(std::begin(out), i));
300  const double r = type_convert<float>(*std::next(std::begin(ref), i));
301  err = std::abs(o - r);
302  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
303  {
304  max_err = err > max_err ? err : max_err;
305  err_count++;
306  if(err_count < 5)
307  {
308  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
309  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
310  }
311  res = false;
312  }
313  }
314  if(!res)
315  {
316  const float error_percent =
317  static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
318  std::cerr << "max err: " << max_err;
319  std::cerr << ", number of errors: " << err_count;
320  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
321  }
322  return res;
323 }
324 
// check_err overload for integral element types (and, when
// CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 is defined, for int4_t).
// Each pair of elements is widened to int64_t, and an element fails when
// |out[i] - ref[i]| > atol. The third parameter (rtol in the floating-point overloads) is
// unnamed here and therefore ignored. Up to four mismatches are printed to std::cerr,
// followed by a summary line. Returns true only if the sizes match and no element fails.
// NOTE(review): `o - r` is computed in int64_t and can overflow if the elements are
// 64-bit values of large magnitude (or uint64_t values above INT64_MAX).
325 template <typename Range, typename RefRange>
326 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
327  std::is_integral_v<ranges::range_value_t<Range>> &&
328  !std::is_same_v<ranges::range_value_t<Range>, bf16_t>)
329 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
330  || std::is_same_v<ranges::range_value_t<Range>, int4_t>
331 #endif
332  ,
333  bool>
334  CK_TILE_HOST check_err(const Range& out,
335  const RefRange& ref,
336  const std::string& msg = "Error: Incorrect results!",
337  double = 0,
338  double atol = 0)
339 {
340  if(out.size() != ref.size())
341  {
342  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
343  << std::endl;
344  return false;
345  }
346 
347  bool res{true};
348  int err_count = 0;
349  int64_t err = 0;
// NOTE(review): the declaration of max_err (source line 350) is missing from this
// listing; presumably an int64_t running maximum of the failing differences. Confirm its
// initial value against the actual header.
351  for(std::size_t i = 0; i < ref.size(); ++i)
352  {
// Widen to int64_t so the subtraction cannot wrap in a narrow (or unsigned) element type,
// and so the values print as numbers rather than characters.
353  const int64_t o = *std::next(std::begin(out), i);
354  const int64_t r = *std::next(std::begin(ref), i);
355  err = std::abs(o - r);
356 
357  if(err > atol)
358  {
359  max_err = err > max_err ? err : max_err;
360  err_count++;
361  if(err_count < 5)
362  {
363  std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
364  << std::endl;
365  }
366  res = false;
367  }
368  }
369  if(!res)
370  {
371  const float error_percent =
372  static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
373  std::cerr << "max err: " << max_err;
374  std::cerr << ", number of errors: " << err_count;
375  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
376  }
377  return res;
378 }
379 
// check_err overload for fp8_t ranges.
// An element passes when its absolute difference (computed after converting both values
// to float, then double) is <= atol, OR when the two encodings are at most
// max_rounding_point_distance steps apart. In either case it must also not be an infinity
// error: non-finite values fail unless both are the same infinity and allow_infinity_ref is
// set. Up to four mismatches are printed to std::cerr, followed by a summary line. Returns
// true only if the sizes match and every element passes.
380 template <typename Range, typename RefRange>
381 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
382  std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
383  bool>
384  CK_TILE_HOST check_err(const Range& out,
385  const RefRange& ref,
386  const std::string& msg = "Error: Incorrect results!",
387  unsigned max_rounding_point_distance = 1,
388  double atol = 1e-1,
389  bool allow_infinity_ref = false)
390 {
391  if(out.size() != ref.size())
392  {
393  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
394  << std::endl;
395  return false;
396  }
397 
398  const auto is_infinity_error = [=](auto o, auto r) {
399  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
400  const bool both_infinite_and_same =
401  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
402 
403  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
404  };
405 
// Distance between two fp8 encodings. When the sign bits (mask 0x80) agree, it is the
// difference of the raw bit patterns read as int8_t.
// NOTE(review): this treats the encoding as sign-magnitude with magnitude increasing
// monotonically with the low 7 bits, so the difference counts representable values
// between o and r. Confirm against the fp8_t definition.
406  static const auto get_rounding_point_distance = [](fp8_t o, fp8_t r) -> unsigned {
407  static const auto get_sign_bit = [](fp8_t v) -> bool {
408  return 0x80 & bit_cast<uint8_t>(v);
409  };
410 
411  if(get_sign_bit(o) ^ get_sign_bit(r))
412  {
// NOTE(review): the return statement for opposite signs (source line 413) is missing from
// this listing; presumably it returns a large sentinel so values of opposite sign never
// pass the distance test. Confirm against the actual header.
414  }
415  else
416  {
417  return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
418  }
419  };
420 
421  bool res{true};
422  int err_count = 0;
423  double err = 0;
424  double max_err = std::numeric_limits<float>::min();
425  for(std::size_t i = 0; i < ref.size(); ++i)
426  {
427  const fp8_t o_fp8 = *std::next(std::begin(out), i);
428  const fp8_t r_fp8 = *std::next(std::begin(ref), i);
429  const double o_fp64 = type_convert<float>(o_fp8);
430  const double r_fp64 = type_convert<float>(r_fp8);
431  err = std::abs(o_fp64 - r_fp64);
// Fail unless the value is within atol or within the allowed encoding distance; also fail
// on any non-finite mismatch regardless of the distance.
432  if(!(less_equal<double>{}(err, atol) ||
433  get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
434  is_infinity_error(o_fp64, r_fp64))
435  {
436  max_err = err > max_err ? err : max_err;
437  err_count++;
438  if(err_count < 5)
439  {
440  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
441  << "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
442  }
443  res = false;
444  }
445  }
446  if(!res)
447  {
448  const float error_percent =
449  static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
450  std::cerr << "max err: " << max_err;
451  std::cerr << ", number of errors: " << err_count;
452  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
453  }
454  return res;
455 }
456 
457 template <typename Range, typename RefRange>
458 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
459  std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
460  bool>
461  CK_TILE_HOST check_err(const Range& out,
462  const RefRange& ref,
463  const std::string& msg = "Error: Incorrect results!",
464  double rtol = 1e-3,
465  double atol = 1e-3,
466  bool allow_infinity_ref = false)
467 {
468  if(out.size() != ref.size())
469  {
470  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
471  << std::endl;
472  return false;
473  }
474 
475  const auto is_infinity_error = [=](auto o, auto r) {
476  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
477  const bool both_infinite_and_same =
478  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
479 
480  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
481  };
482 
483  bool res{true};
484  int err_count = 0;
485  double err = 0;
486  double max_err = std::numeric_limits<float>::min();
487  for(std::size_t i = 0; i < ref.size(); ++i)
488  {
489  const double o = type_convert<float>(*std::next(std::begin(out), i));
490  const double r = type_convert<float>(*std::next(std::begin(ref), i));
491  err = std::abs(o - r);
492  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
493  {
494  max_err = err > max_err ? err : max_err;
495  err_count++;
496  if(err_count < 5)
497  {
498  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
499  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
500  }
501  res = false;
502  }
503  }
504  if(!res)
505  {
506  const float error_percent =
507  static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
508  std::cerr << "max err: " << max_err;
509  std::cerr << ", number of errors: " << err_count;
510  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
511  }
512  return res;
513 }
514 
515 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:39
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
_BitInt(8) fp8_t
Definition: float8.hpp:204
int8_t int8_t
Definition: int8.hpp:20
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Definition: check_err.hpp:75
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
std::ostream & operator<<(std::ostream &os, const std::vector< T > &v)
Definition: check_err.hpp:130
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type CK_TILE_HOST check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6, bool allow_infinity_ref=false)
Definition: check_err.hpp:152
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
double get_relative_threshold(const int number_of_accumulations=1)
Definition: check_err.hpp:22
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
_Float16 half_t
Definition: half.hpp:111
_BitInt(4) int4_t
Definition: data_type.hpp:26
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:10
constexpr bool is_same_v
Definition: type.hpp:283
long int64_t
Definition: data_type.hpp:2474
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:13
Definition: type_traits.hpp:114
Definition: math.hpp:395
Definition: bfloat16.hpp:380