/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/library/utility/check_err.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/library/utility/check_err.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/library/utility/check_err.hpp Source File
check_err.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <algorithm>
7 #include <cmath>
8 #include <cstdlib>
9 #include <iostream>
10 #include <iomanip>
11 #include <iterator>
12 #include <limits>
13 #include <type_traits>
14 #include <vector>
15 
16 #include "ck/ck.hpp"
17 #include "ck/utility/data_type.hpp"
18 #include "ck/utility/type.hpp"
19 #include "ck/host_utility/io.hpp"
20 
22 
23 namespace ck {
24 namespace utils {
25 
26 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
27 double get_relative_threshold(const int number_of_accumulations = 1)
28 {
29  using F4 = ck::f4_t;
30  using F8 = ck::f8_t;
31  using F16 = ck::half_t;
32  using BF16 = ck::bhalf_t;
33  using F32 = float;
34  using I8 = int8_t;
35  using I32 = int32_t;
36 
37  static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
38  is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
39  is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
40  is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
41  "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
42  double compute_error = 0;
43  if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
44  is_same_v<ComputeDataType, int>)
45  {
46  return 0;
47  }
48  else
49  {
50  compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5;
51  }
52 
53  static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
54  is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
55  is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
56  is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
57  "Warning: Unhandled OutDataType for setting up the relative threshold!");
58  double output_error = 0;
59  if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
60  is_same_v<OutDataType, int>)
61  {
62  return 0;
63  }
64  else
65  {
66  output_error = std::pow(2, -NumericUtils<OutDataType>::mant) * 0.5;
67  }
68  double midway_error = std::max(compute_error, output_error);
69 
70  static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
71  is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
72  is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
73  is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
74  "Warning: Unhandled AccDataType for setting up the relative threshold!");
75  double acc_error = 0;
76  if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
77  is_same_v<AccDataType, int>)
78  {
79  return 0;
80  }
81  else
82  {
83  acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
84  }
85  return std::max(acc_error, midway_error);
86 }
87 
88 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
89 double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
90 {
91  using F4 = ck::f4_t;
92  using F8 = ck::f8_t;
93  using F16 = ck::half_t;
94  using BF16 = ck::bhalf_t;
95  using F32 = float;
96  using I8 = int8_t;
97  using I32 = int32_t;
98 
99  static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
100  is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
101  is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
102  is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
103  "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
104  auto expo = std::log2(std::abs(max_possible_num));
105  double compute_error = 0;
106  if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
107  is_same_v<ComputeDataType, int>)
108  {
109  return 0;
110  }
111  else
112  {
113  compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5;
114  }
115 
116  static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
117  is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
118  is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
119  is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
120  "Warning: Unhandled OutDataType for setting up the absolute threshold!");
121  double output_error = 0;
122  if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
123  is_same_v<OutDataType, int>)
124  {
125  return 0;
126  }
127  else
128  {
129  output_error = std::pow(2, expo - NumericUtils<OutDataType>::mant) * 0.5;
130  }
131  double midway_error = std::max(compute_error, output_error);
132 
133  static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
134  is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
135  is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
136  is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
137  "Warning: Unhandled AccDataType for setting up the absolute threshold!");
138  double acc_error = 0;
139  if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
140  is_same_v<AccDataType, int>)
141  {
142  return 0;
143  }
144  else
145  {
146  acc_error =
147  std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
148  }
149  return std::max(acc_error, midway_error);
150 }
151 
152 template <typename Range, typename RefRange>
153 typename std::enable_if<
154  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
155  std::is_floating_point_v<ranges::range_value_t<Range>> &&
156  !std::is_same_v<ranges::range_value_t<Range>, half_t>,
157  bool>::type
158 check_err(const Range& out,
159  const RefRange& ref,
160  const std::string& msg = "Error: Incorrect results!",
161  double rtol = 1e-5,
162  double atol = 3e-6)
163 {
164  if(out.size() != ref.size())
165  {
166  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
167  << std::endl;
168  return false;
169  }
170 
171  bool res{true};
172  int err_count = 0;
173  double err = 0;
174  double max_err = std::numeric_limits<double>::min();
175  for(std::size_t i = 0; i < ref.size(); ++i)
176  {
177  const double o = *std::next(std::begin(out), i);
178  const double r = *std::next(std::begin(ref), i);
179  err = std::abs(o - r);
180  if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
181  {
182  max_err = err > max_err ? err : max_err;
183  err_count++;
184  if(err_count < 5)
185  {
186  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
187  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
188  }
189  res = false;
190  }
191  }
192  if(!res)
193  {
194  const float error_percent =
195  static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
196  std::cerr << "max err: " << max_err;
197  std::cerr << ", number of errors: " << err_count;
198  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
199  }
200  return res;
201 }
202 
203 template <typename Range, typename RefRange>
204 typename std::enable_if<
205  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
206  std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
207  bool>::type
208 check_err(const Range& out,
209  const RefRange& ref,
210  const std::string& msg = "Error: Incorrect results!",
211  double rtol = 1e-1,
212  double atol = 1e-3)
213 {
214  if(out.size() != ref.size())
215  {
216  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
217  << std::endl;
218  return false;
219  }
220 
221  bool res{true};
222  int err_count = 0;
223  double err = 0;
224  // TODO: This is a hack. We should have proper specialization for bhalf_t data type.
225  double max_err = std::numeric_limits<float>::min();
226  for(std::size_t i = 0; i < ref.size(); ++i)
227  {
228  const double o = type_convert<float>(*std::next(std::begin(out), i));
229  const double r = type_convert<float>(*std::next(std::begin(ref), i));
230  err = std::abs(o - r);
231  if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
232  {
233  max_err = err > max_err ? err : max_err;
234  err_count++;
235  if(err_count < 5)
236  {
237  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
238  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
239  }
240  res = false;
241  }
242  }
243  if(!res)
244  {
245  const float error_percent =
246  static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
247  std::cerr << "max err: " << max_err;
248  std::cerr << ", number of errors: " << err_count;
249  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
250  }
251  return res;
252 }
253 
254 template <typename Range, typename RefRange>
255 typename std::enable_if<
256  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
257  std::is_same_v<ranges::range_value_t<Range>, half_t>,
258  bool>::type
259 check_err(const Range& out,
260  const RefRange& ref,
261  const std::string& msg = "Error: Incorrect results!",
262  double rtol = 1e-3,
263  double atol = 1e-3)
264 {
265  if(out.size() != ref.size())
266  {
267  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
268  << std::endl;
269  return false;
270  }
271 
272  bool res{true};
273  int err_count = 0;
274  double err = 0;
275  double max_err = NumericLimits<ranges::range_value_t<Range>>::Min();
276  for(std::size_t i = 0; i < ref.size(); ++i)
277  {
278  const double o = type_convert<float>(*std::next(std::begin(out), i));
279  const double r = type_convert<float>(*std::next(std::begin(ref), i));
280  err = std::abs(o - r);
281  if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
282  {
283  max_err = err > max_err ? err : max_err;
284  err_count++;
285  if(err_count < 5)
286  {
287  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
288  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
289  }
290  res = false;
291  }
292  }
293  if(!res)
294  {
295  const float error_percent =
296  static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
297  std::cerr << "max err: " << max_err;
298  std::cerr << ", number of errors: " << err_count;
299  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
300  }
301  return res;
302 }
303 
304 template <typename Range, typename RefRange>
305 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
306  std::is_integral_v<ranges::range_value_t<Range>> &&
307  !std::is_same_v<ranges::range_value_t<Range>, bhalf_t> &&
308  !std::is_same_v<ranges::range_value_t<Range>, f8_t> &&
309  !std::is_same_v<ranges::range_value_t<Range>, bf8_t>)
310 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
312 #endif
313  ,
314  bool>
315 check_err(const Range& out,
316  const RefRange& ref,
317  const std::string& msg = "Error: Incorrect results!",
318  double = 0,
319  double atol = 0)
320 {
321  if(out.size() != ref.size())
322  {
323  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
324  << std::endl;
325  return false;
326  }
327 
328  bool res{true};
329  int err_count = 0;
330  int64_t err = 0;
332  for(std::size_t i = 0; i < ref.size(); ++i)
333  {
334  const int64_t o = *std::next(std::begin(out), i);
335  const int64_t r = *std::next(std::begin(ref), i);
336  err = std::abs(o - r);
337 
338  if(err > atol)
339  {
340  max_err = err > max_err ? err : max_err;
341  err_count++;
342  if(err_count < 5)
343  {
344  std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
345  << std::endl;
346  }
347  res = false;
348  }
349  }
350  if(!res)
351  {
352  const float error_percent =
353  static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
354  std::cerr << "max err: " << max_err;
355  std::cerr << ", number of errors: " << err_count;
356  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
357  }
358  return res;
359 }
360 
361 template <typename Range, typename RefRange>
362 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
363  std::is_same_v<ranges::range_value_t<Range>, f8_t>),
364  bool>
365 check_err(const Range& out,
366  const RefRange& ref,
367  const std::string& msg = "Error: Incorrect results!",
368  double rtol = 1e-3,
369  double atol = 1e-3)
370 {
371  if(out.size() != ref.size())
372  {
373  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
374  << std::endl;
375  return false;
376  }
377 
378  bool res{true};
379  int err_count = 0;
380  double err = 0;
381  double max_err = std::numeric_limits<float>::min();
382 
383  for(std::size_t i = 0; i < ref.size(); ++i)
384  {
385  const double o = type_convert<float>(*std::next(std::begin(out), i));
386  const double r = type_convert<float>(*std::next(std::begin(ref), i));
387  err = std::abs(o - r);
388 
389  if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
390  {
391  max_err = err > max_err ? err : max_err;
392  err_count++;
393  if(err_count < 5)
394  {
395  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
396  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
397  }
398  res = false;
399  }
400  }
401 
402  if(!res)
403  {
404  std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
405  << " number of errors: " << err_count << std::endl;
406  }
407  return res;
408 }
409 
410 template <typename Range, typename RefRange>
411 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
412  std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
413  bool>
414 check_err(const Range& out,
415  const RefRange& ref,
416  const std::string& msg = "Error: Incorrect results!",
417  double rtol = 1e-3,
418  double atol = 1e-3)
419 {
420  if(out.size() != ref.size())
421  {
422  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
423  << std::endl;
424  return false;
425  }
426 
427  bool res{true};
428  int err_count = 0;
429  double err = 0;
430  double max_err = std::numeric_limits<float>::min();
431  for(std::size_t i = 0; i < ref.size(); ++i)
432  {
433  const double o = type_convert<float>(*std::next(std::begin(out), i));
434  const double r = type_convert<float>(*std::next(std::begin(ref), i));
435  err = std::abs(o - r);
436  if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
437  {
438  max_err = err > max_err ? err : max_err;
439  err_count++;
440  if(err_count < 5)
441  {
442  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
443  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
444  }
445  res = false;
446  }
447  }
448  if(!res)
449  {
450  std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
451  }
452  return res;
453 }
454 
455 template <typename Range, typename RefRange>
456 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
457  std::is_same_v<ranges::range_value_t<Range>, f4_t>),
458  bool>
459 check_err(const Range& out,
460  const RefRange& ref,
461  const std::string& msg = "Error: Incorrect results!",
462  double rtol = 0.5,
463  double atol = 0.5)
464 {
465  if(out.size() != ref.size())
466  {
467  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
468  << std::endl;
469  return false;
470  }
471 
472  bool res{true};
473  int err_count = 0;
474  double err = 0;
475  double max_err = std::numeric_limits<float>::min();
476 
477  for(std::size_t i = 0; i < ref.size(); ++i)
478  {
479  const double o = type_convert<float>(*std::next(std::begin(out), i));
480  const double r = type_convert<float>(*std::next(std::begin(ref), i));
481  err = std::abs(o - r);
482 
483  if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
484  {
485  max_err = err > max_err ? err : max_err;
486  err_count++;
487  if(err_count < 5)
488  {
489  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
490  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
491  }
492  res = false;
493  }
494  }
495 
496  if(!res)
497  {
498  std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
499  << " number of errors: " << err_count << std::endl;
500  }
501  return res;
502 }
503 
504 } // namespace utils
505 } // namespace ck
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
iter_value_t< ranges::iterator_t< R > > range_value_t
Definition: ranges.hpp:28
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6)
Definition: check_err.hpp:158
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Definition: check_err.hpp:89
double get_relative_threshold(const int number_of_accumulations=1)
Definition: check_err.hpp:27
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:33
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:31
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:29
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:37
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:25
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:35
Definition: ck.hpp:267
bf8_fnuz_t bf8_t
Definition: amd_ck_fp8.hpp:1738
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
unsigned _BitInt(4) f4_t
Definition: data_type.hpp:32
_Float16 half_t
Definition: data_type.hpp:30
ushort bhalf_t
Definition: data_type.hpp:29
_BitInt(4) int4_t
Definition: data_type.hpp:31
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
constexpr bool is_same_v
Definition: type.hpp:283
long int64_t
Definition: data_type.hpp:461
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:309
Definition: numeric_utils.hpp:10