53 #ifndef ROCRAND_THREEFRY4_IMPL_H_
54 #define ROCRAND_THREEFRY4_IMPL_H_
56 #include "rocrand/rocrand_common.h"
57 #include "rocrand/rocrand_threefry_common.h"
59 #include <hip/hip_runtime.h>
// Fallback value 20 for THREEFRY4x32_DEFAULT_ROUNDS; the #ifndef guard means it
// only applies when the macro has not already been defined.
61 #ifndef THREEFRY4x32_DEFAULT_ROUNDS
62 #define THREEFRY4x32_DEFAULT_ROUNDS 20
// Same pattern for THREEFRY4x64_DEFAULT_ROUNDS: 20 unless defined beforehand.
65 #ifndef THREEFRY4x64_DEFAULT_ROUNDS
66 #define THREEFRY4x64_DEFAULT_ROUNDS 20
69 namespace rocrand_device
/// Generic declaration of the rotation-constant lookup. It is explicitly deleted,
/// so calling it is a compile error; the <unsigned int> and <unsigned long long>
/// forms of threefry_rotation_array provide the actual tables.
/// NOTE(review): the template header for this declaration is not visible here.
73 __forceinline__ __device__ __host__
int threefry_rotation_array(
int indexX,
int indexY) =
delete;
/// Rotation-constant lookup for unsigned int words.
/// Returns THREEFRY_ROTATION_32_4[indexX][indexY] from a constexpr 8x2 table local
/// to the function. No bounds check is visible: the table dimensions require
/// indexX in [0, 8) and indexY in [0, 2).
/// NOTE(review): the table entries themselves are not visible here.
76 __forceinline__ __device__ __host__
int threefry_rotation_array<unsigned int>(
int indexX,
84 static constexpr
int THREEFRY_ROTATION_32_4[8][2] = {
94 return THREEFRY_ROTATION_32_4[indexX][indexY];
/// Rotation-constant lookup for unsigned long long words.
/// Returns THREEFRY_ROTATION_64_4[indexX][indexY] from a constexpr 8x2 table local
/// to the function. No bounds check is visible: the table dimensions require
/// indexX in [0, 8) and indexY in [0, 2).
/// NOTE(review): the table entries themselves are not visible here.
98 __forceinline__ __device__ __host__
int threefry_rotation_array<unsigned long long>(
int indexX,
103 static constexpr
int THREEFRY_ROTATION_64_4[8][2] = {
113 return THREEFRY_ROTATION_64_4[indexX][indexY];
/// Base class template for a Threefry engine that works on four words at a time.
/// \tparam state_value vector type with members x, y, z, w; used for counter, key and result
/// \tparam value       scalar type of one word, and the type returned by next()
/// \tparam Nrounds     number of rounds run by threefry_rounds (statically limited to 72)
116 template<
typename state_value,
typename value,
unsigned int Nrounds>
117 class threefry_engine4_base
// Generator state. Besides substate, the methods of this class read and write
// counter, key and result members through m_state.
120 struct threefry_state_4
// Index of the element of result that next() hands out; next() and
// discard_impl() wrap it back below 4.
125 unsigned int substate;
127 using state_type = threefry_state_4;
128 using state_vector_type = state_value;
/// Skips \p offset numbers: advances substate/counter via discard_impl and then
/// recomputes m_state.result for the new counter.
/// \param offset amount passed to discard_impl
131 __forceinline__ __device__ __host__
void discard(
unsigned long long offset)
133 this->discard_impl(offset);
// discard_impl only moves the counter/substate; refresh the cached block here.
134 this->m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
/// Skips \p subsequence subsequences: advances the upper counter words via
/// discard_subsequence_impl and then recomputes m_state.result.
/// \param subsequence amount passed to discard_subsequence_impl
142 __forceinline__ __device__ __host__
void discard_subsequence(
unsigned long long subsequence)
144 this->discard_subsequence_impl(subsequence);
// The counter changed, so the cached block must be regenerated.
145 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
/// Call operator producing a single number of type value.
/// NOTE(review): the body is not visible here; presumably it forwards to next() - confirm.
148 __forceinline__ __device__ __host__ value operator()()
/// Produces one number: reads element m_state.substate of the cached result and,
/// once substate reaches 4, starts a new block (substate reset to 0, counter
/// bumped, result recomputed).
/// NOTE(review): the increment of substate and the return statement are not
/// visible here - confirm against the full file.
153 __forceinline__ __device__ __host__ value next()
// Two ways of indexing the result vector, chosen by platform macro.
155 #if defined(__HIP_PLATFORM_AMD__)
156 value ret = ROCRAND_HIPVEC_ACCESS(m_state.result)[m_state.substate];
// Alternative path: treat the address of .x as the start of an array of words.
158 value ret = (&m_state.result.x)[m_state.substate];
// All four words of the current block consumed: generate the next block.
161 if(m_state.substate == 4)
163 m_state.substate = 0;
164 m_state.counter = this->bump_counter(m_state.counter);
165 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
/// Produces four numbers at once as a state_value.
/// Saves the current result, bumps the counter and generates the next result,
/// then lets interleave() combine the old and new blocks according to substate.
/// substate itself is not modified in the visible lines.
170 __forceinline__ __device__ __host__ state_value next4()
// Keep the block that is being replaced; interleave() needs both.
172 state_value ret = m_state.result;
173 m_state.counter = this->bump_counter(m_state.counter);
174 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
176 return this->interleave(ret, m_state.result);
/// Static round function: runs Nrounds rounds on a four-word block derived from
/// \p counter and returns a state_value (callers pass m_state.counter and m_state.key).
/// NOTE(review): the declarations of X and ks, the key/counter setup and the
/// add/xor mixing steps between the rotations are not visible here; the comments
/// below describe only what the visible lines do.
180 __forceinline__ __device__ __host__
static state_value threefry_rounds(state_value counter,
// Compile-time limit on the round count.
186 static_assert(Nrounds <= 72,
"72 or less only supported in threefry rounds");
// Fifth key-schedule word starts from the parity constant for this word type.
188 ks[4] = skein_ks_parity<value>();
211 for(
unsigned int round_idx = 0; round_idx < Nrounds; round_idx++)
// Rotation amounts repeat with period 8 (row = round_idx & 7).
213 int rot_0 = threefry_rotation_array<value>(round_idx & 7u, 0);
214 int rot_1 = threefry_rotation_array<value>(round_idx & 7u, 1);
// Bit 1 of round_idx selects which of X.y / X.w is rotated by rot_0 vs rot_1.
215 if((round_idx & 2u) == 0)
218 X.y = rotl<value>(X.y, rot_0);
221 X.w = rotl<value>(X.w, rot_1);
227 X.w = rotl<value>(X.w, rot_0);
230 X.y = rotl<value>(X.y, rot_1);
// After every fourth round (round_idx = 3, 7, 11, ...) inject key-schedule words.
234 if((round_idx & 3u) == 3)
236 unsigned int inject_idx = round_idx / 4;
// Key-schedule words are taken cyclically from the 5-entry ks array.
238 X.x += ks[(1 + inject_idx) % 5];
239 X.y += ks[(2 + inject_idx) % 5];
240 X.z += ks[(3 + inject_idx) % 5];
241 X.w += ks[(4 + inject_idx) % 5];
// The last word additionally receives the injection number (1 + inject_idx).
242 X.w += 1 + inject_idx;
/// Advances substate and counter to skip \p offset numbers.
/// No call to threefry_rounds appears in the visible lines; discard() recomputes
/// m_state.result after calling this.
/// \param offset number of values to skip
251 __forceinline__ __device__ __host__
void discard_impl(
unsigned long long offset)
// Position within the current 4-word block advances by offset mod 4.
254 m_state.substate += offset & 3;
// Whole blocks to skip.
255 unsigned long long counter_offset = offset / 4;
// If substate ran past the block, carry one extra block into the counter offset...
256 counter_offset += m_state.substate < 4 ? 0 : 1;
// ...and bring substate back into [0, 4). Adding -4 to the unsigned substate
// wraps modulo 2^N, which subtracts 4.
257 m_state.substate += m_state.substate < 4 ? 0 : -4;
259 this->discard_state(counter_offset);
/// Adds \p subsequence to the upper two counter words (z and w).
/// No call to threefry_rounds appears in the visible lines; discard_subsequence()
/// recomputes m_state.result after calling this.
/// NOTE(review): the declarations of lo and hi are not visible here.
263 __forceinline__ __device__ __host__
void
264 discard_subsequence_impl(
unsigned long long subsequence)
// presumably splits subsequence into a low part (lo) and a high part (hi) - verify against split_ull
267 ::rocrand_device::detail::split_ull(lo, hi, subsequence);
// Remember z so that wrap-around of the addition can be detected.
269 value old_counter = m_state.counter.z;
270 m_state.counter.z += lo;
// z smaller than before means the addition wrapped: carry 1 into w.
271 m_state.counter.w += hi + (m_state.counter.z < old_counter ? 1 : 0);
/// Adds \p offset to the counter, treating x, y, z, w as successive words with
/// carry propagation from x upwards.
/// No call to threefry_rounds appears in the visible lines.
/// NOTE(review): the declarations of lo and hi are not visible here.
276 __forceinline__ __device__ __host__
void discard_state(
unsigned long long offset)
// presumably splits offset into a low part (lo) and a high part (hi) - verify against split_ull
279 ::rocrand_device::detail::split_ull(lo, hi, offset);
// Snapshot of the counter used to detect wrap-around of each word.
281 state_value old_counter = m_state.counter;
282 m_state.counter.x += lo;
// A word that became smaller than its old value wrapped: carry 1 into the next word.
283 m_state.counter.y += hi + (m_state.counter.x < old_counter.x ? 1 : 0);
284 m_state.counter.z += (m_state.counter.y < old_counter.y ? 1 : 0);
285 m_state.counter.w += (m_state.counter.z < old_counter.z ? 1 : 0);
/// Static helper that takes a counter by value and returns a state_value;
/// next() and next4() assign the result back to m_state.counter.
/// NOTE(review): the statements that modify counter.x/y/z/w and the return are
/// not visible here; presumably each word is incremented by the carry computed
/// below - confirm against the full file.
288 __forceinline__ __device__ __host__
static state_value bump_counter(state_value counter)
// Carry is 1 only when counter.x is zero at this point.
291 value add = counter.x == 0 ? 1 : 0;
// The carry survives only while each successive word is also zero.
293 add = counter.y == 0 ? add : 0;
295 add = counter.z == 0 ? add : 0;
/// Builds a four-word vector from two consecutive blocks, starting at position
/// m_state.substate of \p prev and continuing into \p next.
/// Const member: reads m_state.substate only.
/// NOTE(review): only cases 1-3 are visible here; a case for substate 0 is
/// presumably present in the full file - confirm.
/// \param prev block that was current before the counter was bumped
/// \param next block generated for the bumped counter
300 __forceinline__ __device__ __host__ state_value interleave(
const state_value prev,
301 const state_value next)
const
303 switch(m_state.substate)
// substate k: take the last (4 - k) words of prev followed by the first k words of next.
306 case 1:
return state_value{prev.y, prev.z, prev.w, next.x};
307 case 2:
return state_value{prev.z, prev.w, next.x, next.y};
308 case 3:
return state_value{prev.w, next.x, next.y, next.z};
// Tells the compiler that control never falls out of the switch.
310 __builtin_unreachable();
// Engine state (counter, key, cached result and substate) used by all member functions.
314 threefry_state_4 m_state;