15 template <
typename... Is>
27 template <
typename... Is>
39 template <
typename... Is>
42 return ck::type_convert<ck::half_t>(
value);
51 template <
typename... Is>
54 return ck::type_convert<ck::bhalf_t>(
value);
58 #if defined CK_ENABLE_FP8
64 template <
typename... Is>
67 return ck::type_convert<ck::f8_t>(
value);
77 template <
typename... Is>
80 return ck::type_convert<ck::f4_t>(
value);
89 template <
typename... Is>
101 template <
typename... Is>
110 template <
typename T>
116 template <
typename... Is>
129 template <
typename... Is>
133 return ck::type_convert<ck::bhalf_t>(tmp);
143 template <
typename... Is>
156 template <
typename... Is>
166 #if defined CK_ENABLE_FP8
173 template <
typename... Is>
177 return ck::type_convert<ck::f8_t>(tmp);
182 #if defined CK_ENABLE_BF8
189 template <
typename... Is>
193 return ck::type_convert<ck::bf8_t>(tmp);
204 template <
typename... Is>
208 return ck::type_convert<ck::f4_t>(tmp);
212 template <
typename T>
218 template <
typename... Is>
221 float tmp = float(std::rand()) / float(RAND_MAX);
233 template <
typename... Is>
236 float tmp = float(std::rand()) / float(RAND_MAX);
240 return ck::type_convert<ck::bhalf_t>(fp32_tmp);
244 #if defined CK_ENABLE_FP8
251 template <
typename... Is>
254 float tmp = float(std::rand()) / float(RAND_MAX);
258 return ck::type_convert<ck::f8_t>(fp32_tmp);
263 #if defined CK_ENABLE_BF8
270 template <
typename... Is>
273 float tmp = float(std::rand()) / float(RAND_MAX);
277 return ck::type_convert<ck::bf8_t>(fp32_tmp);
288 template <
typename... Is>
291 float tmp = float(std::rand()) / float(RAND_MAX);
295 return ck::type_convert<ck::f4_t>(fp32_tmp);
299 template <
typename T>
308 template <
typename... Is>
313 return ck::type_convert<T>(tmp);
319 template <
typename... Ts>
323 return std::accumulate(dims.begin(),
326 [](
bool init,
ck::index_t x) ->
int { return init != (x % 2); })
349 template <
typename T, ck::index_t Dim>
352 template <
typename... Ts>
357 float tmp = dims[Dim];
358 return ck::type_convert<T>(tmp);
362 template <
typename T,
size_t NumEffectiveDim = 2>
367 template <
typename... Ts>
371 size_t start_dim = dims.size() - NumEffectiveDim;
373 for(
size_t i = start_dim + 1; i < dims.size(); i++)
375 pred &= (dims[start_dim] == dims[i]);
377 return pred ?
value : T{0};
int8_t int8_t
Definition: int8.hpp:20
bf8_fnuz_t bf8_t
Definition: amd_ck_fp8.hpp:991
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:990
unsigned _BitInt(4) f4_t
Definition: data_type.hpp:27
_Float16 half_t
Definition: data_type.hpp:25
ushort bhalf_t
Definition: data_type.hpp:24
int32_t index_t
Definition: ck.hpp:289
Definition: host_tensor_generator.hpp:14
T operator()(Is...)
Definition: host_tensor_generator.hpp:16
ck::bhalf_t operator()(Is...)
Definition: host_tensor_generator.hpp:52
ck::f4_t operator()(Is...)
Definition: host_tensor_generator.hpp:78
ck::half_t operator()(Is...)
Definition: host_tensor_generator.hpp:40
ck::pk_i4_t operator()(Is...)
Definition: host_tensor_generator.hpp:102
int8_t operator()(Is...)
Definition: host_tensor_generator.hpp:90
Definition: host_tensor_generator.hpp:24
T value
Definition: host_tensor_generator.hpp:25
T operator()(Is...)
Definition: host_tensor_generator.hpp:28
ck::bhalf_t operator()(Is...)
Definition: host_tensor_generator.hpp:130
ck::f4_t operator()(Is...)
Definition: host_tensor_generator.hpp:205
ck::pk_i4_t operator()(Is...)
Definition: host_tensor_generator.hpp:157
int8_t operator()(Is...)
Definition: host_tensor_generator.hpp:144
Definition: host_tensor_generator.hpp:112
int max_value
Definition: host_tensor_generator.hpp:114
int min_value
Definition: host_tensor_generator.hpp:113
T operator()(Is...)
Definition: host_tensor_generator.hpp:117
ck::bhalf_t operator()(Is...)
Definition: host_tensor_generator.hpp:234
ck::f4_t operator()(Is...)
Definition: host_tensor_generator.hpp:289
Definition: host_tensor_generator.hpp:214
float max_value
Definition: host_tensor_generator.hpp:216
float min_value
Definition: host_tensor_generator.hpp:215
T operator()(Is...)
Definition: host_tensor_generator.hpp:219
Definition: host_tensor_generator.hpp:301
std::mt19937 generator
Definition: host_tensor_generator.hpp:302
GeneratorTensor_4(float mean, float stddev, unsigned int seed=1)
Definition: host_tensor_generator.hpp:305
T operator()(Is...)
Definition: host_tensor_generator.hpp:309
std::normal_distribution< float > distribution
Definition: host_tensor_generator.hpp:303
Definition: host_tensor_generator.hpp:318
float operator()(Ts... Xs) const
Definition: host_tensor_generator.hpp:320
Definition: host_tensor_generator.hpp:364
T operator()(Ts... Xs) const
Definition: host_tensor_generator.hpp:368
T value
Definition: host_tensor_generator.hpp:365
Is used to generate sequential values based on the specified dimension.
Definition: host_tensor_generator.hpp:351
T operator()(Ts... Xs) const
Definition: host_tensor_generator.hpp:353
Definition: data_type.hpp:320