utils Namespace Reference

utils Namespace Reference

Composable Kernel: ck::utils Namespace Reference
ck::utils Namespace Reference

Namespaces

 conv
 

Classes

struct  FillUniformDistribution
 
struct  FillUniformDistributionIntegerValue
 
struct  FillMonotonicSeq
 A functor for filling a container with a monotonically increasing or decreasing sequence. More...
 
struct  FillConstant
 
struct  TransformIntoStructuralSparsity
 
union  cvt
 

Functions

template<typename ComputeDataType , typename OutDataType , typename AccDataType = ComputeDataType>
double get_relative_threshold (const int number_of_accumulations=1)
 
template<typename ComputeDataType , typename OutDataType , typename AccDataType = ComputeDataType>
double get_absolute_threshold (const double max_possible_num, const int number_of_accumulations=1)
 
template<typename Range , typename RefRange >
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type check_err (const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6)
 
template<typename Range , typename RefRange >
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_same_v< ranges::range_value_t< Range >, bhalf_t >, bool >::type check_err (const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-1, double atol=1e-3)
 
template<typename Range , typename RefRange >
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type check_err (const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-3, double atol=1e-3)
 
template<typename Range , typename RefRange >
std::enable_if_t<(std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange >> &&std::is_integral_v< ranges::range_value_t< Range >> &&!std::is_same_v< ranges::range_value_t< Range >, bhalf_t > &&!std::is_same_v< ranges::range_value_t< Range >, f8_t > &&!std::is_same_v< ranges::range_value_t< Range >, bf8_t >), bool > check_err (const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double=0, double atol=0)
 
template<typename Range , typename RefRange >
std::enable_if_t<(std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange >> &&std::is_same_v< ranges::range_value_t< Range >, f8_t >), bool > check_err (const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-3, double atol=1e-3)
 
template<typename Range , typename RefRange >
std::enable_if_t<(std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange >> &&std::is_same_v< ranges::range_value_t< Range >, bf8_t >), bool > check_err (const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-3, double atol=1e-3)
 
template<typename Range , typename RefRange >
std::enable_if_t<(std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange >> &&std::is_same_v< ranges::range_value_t< Range >, f4_t >), bool > check_err (const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=0.5, double atol=0.5)
 
template<typename Layout >
void validate_gemm_stride (int M, int N, int stride, const std::string &stride_name="Stride")
 
template<typename ALayout , typename BLayout , typename CLayout >
void validate_gemm_strides_abc (int M, int N, int K, int StrideA, int StrideB, int StrideC)
 
template<typename T >
__host__ constexpr __device__ int32_t get_exponent_value (T x)
 
template<>
__host__ constexpr __device__ int32_t get_exponent_value< e8m0_bexp_t > (e8m0_bexp_t x)
 
template<typename X , typename Y , bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ Y cast_to_f8 (X x, uint32_t rng)
 
template<typename X , typename Y , bool negative_zero_nan>
__host__ __device__ Y cast_from_f8 (X x)
 
template<>
__host__ __device__ bool is_nan< f4_t > (e8m0_bexp_t const scale, f4_t const dataBytes[[maybe_unused]])
 
template<>
__host__ __device__ bool is_inf< f4_t > (e8m0_bexp_t const scale[[maybe_unused]], f4_t const data[[maybe_unused]])
 
template<>
__host__ __device__ bool is_zero< f4_t > (e8m0_bexp_t const scale[[maybe_unused]], f4_t const data)
 
template<>
__host__ __device__ float to_float< f4_t > (e8m0_bexp_t const scale, f4_t const data)
 
template<>
__host__ __device__ f4_t sat_convert_to_type< f4_t > (float value)
 
template<>
__host__ __device__ f4_t sat_convert_to_type_sr< f4_t > (float value, uint32_t seed)
 
template<>
__host__ __device__ bool is_nan< f6_t > (e8m0_bexp_t const scale, f6_t const dataBytes[[maybe_unused]])
 Checks if an f6_t value is NaN based on the provided scale. More...
 
template<>
__host__ __device__ bool is_nan< bf6_t > (e8m0_bexp_t const scale, bf6_t const dataBytes[[maybe_unused]])
 Checks if a bf6_t value is NaN based on the provided scale. More...
 
template<>
__host__ __device__ bool is_inf< f6_t > (e8m0_bexp_t const scale[[maybe_unused]], f6_t const data[[maybe_unused]])
 Checks if an f6_t value is infinite. More...
 
template<>
__host__ __device__ bool is_inf< bf6_t > (e8m0_bexp_t const scale[[maybe_unused]], bf6_t const data[[maybe_unused]])
 Checks if a bf6_t value is infinite. More...
 
template<>
__host__ __device__ bool is_zero< f6_t > (e8m0_bexp_t const scale, f6_t const data)
 Checks whether an f6_t value is zero. More...
 
template<>
__host__ __device__ bool is_zero< bf6_t > (e8m0_bexp_t const scale, bf6_t const data)
 Checks whether a bf6_t value is zero. More...
 
template<>
__host__ __device__ float to_float< f6_t > (e8m0_bexp_t const scale, f6_t const data)
 Converts an f6_t value to a float based on an e8m0_bexp_t scale factor. More...
 
template<>
__host__ __device__ float to_float< bf6_t > (e8m0_bexp_t const scale, bf6_t const data)
 Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor. More...
 
template<>
__host__ __device__ f6_t sat_convert_to_type< f6_t > (float value)
 Converts a float to f6_t with saturation. More...
 
template<>
__host__ __device__ bf6_t sat_convert_to_type< bf6_t > (float value)
 Converts a float to bf6_t with saturation. More...
 
template<>
__host__ __device__ f6_t sat_convert_to_type_sr< f6_t > (float value, uint32_t seed)
 Converts a float to f6_t with saturation and stochastic rounding. More...
 
template<>
__host__ __device__ bf6_t sat_convert_to_type_sr< bf6_t > (float value, uint32_t seed)
 Converts a float to bf6_t with saturation and stochastic rounding. More...
 
template<typename DTYPE >
bool getDataHasInf ()
 
template<typename T >
__host__ __device__ bool is_zero (e8m0_bexp_t const scale, T const data)
 
template<typename T >
__host__ __device__ bool is_nan (e8m0_bexp_t const scale, T const data)
 
template<typename T >
__host__ __device__ bool is_inf (e8m0_bexp_t const scale, T const data)
 
template<typename T >
__host__ __device__ bool is_subnormal (T x)
 
template<typename T >
__host__ __device__ double get_mantissa_value (T x)
 
template<typename T >
__host__ __device__ bool get_data_has_inf ()
 
template<typename T >
__host__ __device__ float convert_to_float (T data, int scale_exp)
 
template<typename T >
__host__ __device__ float to_float (e8m0_bexp_t const scale, T const data)
 
template<typename T >
__host__ __device__ T sat_convert_to_type (float value)
 
template<typename T >
__host__ __device__ T sat_convert_to_type_sr (float value, uint32_t seed)
 
template<typename T >
__host__ __device__ T convert_to_type (float value)
 
template<typename T >
__host__ __device__ T convert_to_type_sr (float value, uint32_t seed)
 

Function Documentation

◆ cast_from_f8()

template<typename X , typename Y , bool negative_zero_nan>
__host__ __device__ Y ck::utils::cast_from_f8 ( X  x)

◆ cast_to_f8()

template<typename X , typename Y , bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ Y ck::utils::cast_to_f8 ( X  x,
uint32_t  rng 
)

◆ check_err() [1/7]

template<typename Range , typename RefRange >
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::is_same_v<ranges::range_value_t<Range>, f4_t>), bool> ck::utils::check_err ( const Range &  out,
const RefRange &  ref,
const std::string &  msg = "Error: Incorrect results!",
double  rtol = 0.5,
double  atol = 0.5 
)

◆ check_err() [2/7]

template<typename Range , typename RefRange >
std::enable_if< std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange> > && std::is_same_v<ranges::range_value_t<Range>, bhalf_t>, bool>::type ck::utils::check_err ( const Range &  out,
const RefRange &  ref,
const std::string &  msg = "Error: Incorrect results!",
double  rtol = 1e-1,
double  atol = 1e-3 
)

◆ check_err() [3/7]

template<typename Range , typename RefRange >
std::enable_if< std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange> > && std::is_same_v<ranges::range_value_t<Range>, half_t>, bool>::type ck::utils::check_err ( const Range &  out,
const RefRange &  ref,
const std::string &  msg = "Error: Incorrect results!",
double  rtol = 1e-3,
double  atol = 1e-3 
)

◆ check_err() [4/7]

template<typename Range , typename RefRange >
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::is_same_v<ranges::range_value_t<Range>, f8_t>), bool> ck::utils::check_err ( const Range &  out,
const RefRange &  ref,
const std::string &  msg = "Error: Incorrect results!",
double  rtol = 1e-3,
double  atol = 1e-3 
)

◆ check_err() [5/7]

template<typename Range , typename RefRange >
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::is_same_v<ranges::range_value_t<Range>, bf8_t>), bool> ck::utils::check_err ( const Range &  out,
const RefRange &  ref,
const std::string &  msg = "Error: Incorrect results!",
double  rtol = 1e-3,
double  atol = 1e-3 
)

◆ check_err() [6/7]

template<typename Range , typename RefRange >
std::enable_if< std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange> > && std::is_floating_point_v<ranges::range_value_t<Range> > && !std::is_same_v<ranges::range_value_t<Range>, half_t>, bool>::type ck::utils::check_err ( const Range &  out,
const RefRange &  ref,
const std::string &  msg = "Error: Incorrect results!",
double  rtol = 1e-5,
double  atol = 3e-6 
)

◆ check_err() [7/7]

template<typename Range , typename RefRange >
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::is_integral_v<ranges::range_value_t<Range>> && !std::is_same_v<ranges::range_value_t<Range>, bhalf_t> && !std::is_same_v<ranges::range_value_t<Range>, f8_t> && !std::is_same_v<ranges::range_value_t<Range>, bf8_t>) , bool> ck::utils::check_err ( const Range &  out,
const RefRange &  ref,
const std::string &  msg = "Error: Incorrect results!",
double  = 0,
double  atol = 0 
)

◆ convert_to_float()

template<typename T >
__host__ __device__ float ck::utils::convert_to_float ( T  data,
int  scale_exp 
)

◆ convert_to_type()

template<typename T >
__host__ __device__ T ck::utils::convert_to_type ( float  value)
inline

◆ convert_to_type_sr()

template<typename T >
__host__ __device__ T ck::utils::convert_to_type_sr ( float  value,
uint32_t  seed 
)
inline

◆ get_absolute_threshold()

template<typename ComputeDataType , typename OutDataType , typename AccDataType = ComputeDataType>
double ck::utils::get_absolute_threshold ( const double  max_possible_num,
const int  number_of_accumulations = 1 
)

◆ get_data_has_inf()

template<typename T >
__host__ __device__ bool ck::utils::get_data_has_inf ( )
inline

◆ get_exponent_value()

template<typename T >
__host__ constexpr __device__ int32_t ck::utils::get_exponent_value ( T  x)
inline constexpr

◆ get_exponent_value< e8m0_bexp_t >()

template<>
__host__ constexpr __device__ int32_t ck::utils::get_exponent_value< e8m0_bexp_t > ( e8m0_bexp_t  x)
inline constexpr

◆ get_mantissa_value()

template<typename T >
__host__ __device__ double ck::utils::get_mantissa_value ( T  x)
inline

◆ get_relative_threshold()

template<typename ComputeDataType , typename OutDataType , typename AccDataType = ComputeDataType>
double ck::utils::get_relative_threshold ( const int  number_of_accumulations = 1)

◆ getDataHasInf()

template<typename DTYPE >
bool ck::utils::getDataHasInf ( )
inline

◆ is_inf()

template<typename T >
__host__ __device__ bool ck::utils::is_inf ( e8m0_bexp_t const  scale,
T const  data 
)
inline

◆ is_inf< bf6_t >()

template<>
__host__ __device__ bool ck::utils::is_inf< bf6_t > ( e8m0_bexp_t const scale [[maybe_unused]],
bf6_t const data [[maybe_unused]]
)
inline

Checks if a bf6_t value is infinite.

Because bf6_t does not support infinite values, this function always returns false.

Parameters
scale: The exponent scale factor (e8m0_bexp_t) used for bf6_t.
data: The bf6_t value to check.
Returns
Always false, as infinity is not represented in bf6_t.

◆ is_inf< f4_t >()

template<>
__host__ __device__ bool ck::utils::is_inf< f4_t > ( e8m0_bexp_t const scale [[maybe_unused]],
f4_t const data [[maybe_unused]]
)
inline

◆ is_inf< f6_t >()

template<>
__host__ __device__ bool ck::utils::is_inf< f6_t > ( e8m0_bexp_t const scale [[maybe_unused]],
f6_t const data [[maybe_unused]]
)
inline

Checks if an f6_t value is infinite.

Because f6_t does not support infinite values, this function always returns false.

Parameters
scale: The exponent scale factor (e8m0_bexp_t) used for f6_t.
data: The f6_t value to check.
Returns
Always false, as infinity is not represented in f6_t.

◆ is_nan()

template<typename T >
__host__ __device__ bool ck::utils::is_nan ( e8m0_bexp_t const  scale,
T const  data 
)
inline

◆ is_nan< bf6_t >()

template<>
__host__ __device__ bool ck::utils::is_nan< bf6_t > ( e8m0_bexp_t const  scale,
bf6_t const dataBytes [[maybe_unused]]
)
inline

Checks if a bf6_t value is NaN based on the provided scale.

For bf6_t data, NaN cannot be represented directly. Instead, this function determines NaN by checking if the scale is set to a quiet NaN.

Parameters
scale: The exponent scale factor (e8m0_bexp_t) used for bf6_t.
dataBytes: The bf6_t value to check (unused in this implementation).
Returns
true if the scale indicates a NaN value, false otherwise.

◆ is_nan< f4_t >()

template<>
__host__ __device__ bool ck::utils::is_nan< f4_t > ( e8m0_bexp_t const  scale,
f4_t const dataBytes [[maybe_unused]]
)
inline

◆ is_nan< f6_t >()

template<>
__host__ __device__ bool ck::utils::is_nan< f6_t > ( e8m0_bexp_t const  scale,
f6_t const dataBytes [[maybe_unused]]
)
inline

Checks if an f6_t value is NaN based on the provided scale.

For f6_t data, NaN cannot be represented directly. Instead, this function determines NaN by checking if the scale is set to a quiet NaN.

Parameters
scale: The exponent scale factor (e8m0_bexp_t) used for f6_t.
dataBytes: The f6_t value to check (unused in this implementation).
Returns
true if the scale indicates a NaN value, false otherwise.

◆ is_subnormal()

template<typename T >
__host__ __device__ bool ck::utils::is_subnormal ( T  x)
inline

◆ is_zero()

template<typename T >
__host__ __device__ bool ck::utils::is_zero ( e8m0_bexp_t const  scale,
T const  data 
)
inline

◆ is_zero< bf6_t >()

template<>
__host__ __device__ bool ck::utils::is_zero< bf6_t > ( e8m0_bexp_t const  scale,
bf6_t const  data 
)
inline

Checks whether a bf6_t value is zero.

If the specified bf6_t is NaN, this function returns false. Otherwise, it masks out the sign bits and checks if the remaining bits are zero.

Parameters
scale: The exponent scale factor (e8m0_bexp_t) used for bf6_t.
data: The bf6_t value to check.
Returns
true if the value is zero; otherwise false.

◆ is_zero< f4_t >()

template<>
__host__ __device__ bool ck::utils::is_zero< f4_t > ( e8m0_bexp_t const scale [[maybe_unused]],
f4_t const  data 
)
inline

◆ is_zero< f6_t >()

template<>
__host__ __device__ bool ck::utils::is_zero< f6_t > ( e8m0_bexp_t const  scale,
f6_t const  data 
)
inline

Checks whether an f6_t value is zero.

If the specified f6_t is NaN, this function returns false. Otherwise, it masks out the sign bits and checks if the remaining bits are zero.

Parameters
scale: The exponent scale factor (e8m0_bexp_t) used for f6_t.
data: The f6_t value to check.
Returns
true if the value is zero; otherwise false.

◆ sat_convert_to_type()

template<typename T >
__host__ __device__ T ck::utils::sat_convert_to_type ( float  value)

◆ sat_convert_to_type< bf6_t >()

template<>
__host__ __device__ bf6_t ck::utils::sat_convert_to_type< bf6_t > ( float  value)
inline

Converts a float to bf6_t with saturation.

If the input is NaN or exceeds the representable range for bf6_t, returns the corresponding max normal mask. Handles subnormal cases by returning zero with the appropriate sign.

Parameters
value: The float value to be converted.
Returns
The saturated bf6_t value.

◆ sat_convert_to_type< f4_t >()

template<>
__host__ __device__ f4_t ck::utils::sat_convert_to_type< f4_t > ( float  value)
inline

◆ sat_convert_to_type< f6_t >()

template<>
__host__ __device__ f6_t ck::utils::sat_convert_to_type< f6_t > ( float  value)
inline

Converts a float to f6_t with saturation.

If the input is NaN or exceeds the representable range for f6_t, returns the corresponding max normal mask. Handles subnormal cases by returning zero with the appropriate sign.

Parameters
value: The float value to be converted.
Returns
The saturated f6_t value.

◆ sat_convert_to_type_sr()

template<typename T >
__host__ __device__ T ck::utils::sat_convert_to_type_sr ( float  value,
uint32_t  seed 
)

◆ sat_convert_to_type_sr< bf6_t >()

template<>
__host__ __device__ bf6_t ck::utils::sat_convert_to_type_sr< bf6_t > ( float  value,
uint32_t  seed 
)
inline

Converts a float to bf6_t with saturation and stochastic rounding.

If the input is NaN or exceeds the representable range for bf6_t, returns the corresponding max normal mask. Handles subnormal cases by returning zero with the appropriate sign.

Parameters
value: The float value to be converted.
Returns
The saturated bf6_t value.

◆ sat_convert_to_type_sr< f4_t >()

template<>
__host__ __device__ f4_t ck::utils::sat_convert_to_type_sr< f4_t > ( float  value,
uint32_t  seed 
)
inline

◆ sat_convert_to_type_sr< f6_t >()

template<>
__host__ __device__ f6_t ck::utils::sat_convert_to_type_sr< f6_t > ( float  value,
uint32_t  seed 
)
inline

Converts a float to f6_t with saturation and stochastic rounding.

If the input is NaN or exceeds the representable range for f6_t, returns the corresponding max normal mask. Handles subnormal cases by returning zero with the appropriate sign.

Parameters
value: The float value to be converted.
Returns
The saturated f6_t value.

◆ to_float()

template<typename T >
__host__ __device__ float ck::utils::to_float ( e8m0_bexp_t const  scale,
T const  data 
)
inline

◆ to_float< bf6_t >()

template<>
__host__ __device__ float ck::utils::to_float< bf6_t > ( e8m0_bexp_t const  scale,
bf6_t const  data 
)
inline

Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor.

Checks if the bf6_t value is NaN or zero before performing the conversion. Applies the exponent from the scale to compute the final float result.

Parameters
scale: The exponent scale factor (e8m0_bexp_t) used for bf6_t.
data: The bf6_t value to convert.
Returns
The converted float value.

◆ to_float< f4_t >()

template<>
__host__ __device__ float ck::utils::to_float< f4_t > ( e8m0_bexp_t const  scale,
f4_t const  data 
)
inline

◆ to_float< f6_t >()

template<>
__host__ __device__ float ck::utils::to_float< f6_t > ( e8m0_bexp_t const  scale,
f6_t const  data 
)
inline

Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.

Checks if the f6_t value is NaN or zero before performing the conversion. Applies the exponent from the scale to compute the final float result.

Parameters
scale: The exponent scale factor (e8m0_bexp_t) used for f6_t.
data: The f6_t value to convert.
Returns
The converted float value.

◆ validate_gemm_stride()

template<typename Layout >
void ck::utils::validate_gemm_stride ( int  M,
int  N,
int  stride,
const std::string &  stride_name = "Stride" 
)
inline

◆ validate_gemm_strides_abc()

template<typename ALayout , typename BLayout , typename CLayout >
void ck::utils::validate_gemm_strides_abc ( int  M,
int  N,
int  K,
int  StrideA,
int  StrideB,
int  StrideC 
)
inline