53 #ifndef ROCRAND_THREEFRY2_IMPL_H_
54 #define ROCRAND_THREEFRY2_IMPL_H_
56 #include "rocrand/rocrand_common.h"
57 #include "rocrand/rocrand_threefry_common.h"
59 #include <hip/hip_runtime.h>
// Default Threefry round count, presumably for the 2x32 engine variant (not
// referenced in this excerpt). The #ifndef guard lets users override it by
// defining THREEFRY2x32_DEFAULT_ROUNDS before this header is included.
61 #ifndef THREEFRY2x32_DEFAULT_ROUNDS
62 #define THREEFRY2x32_DEFAULT_ROUNDS 20
// NOTE(review): the closing #endif lines of these guards are not part of this
// excerpt.
// Default Threefry round count, presumably for the 2x64 engine variant;
// overridable the same way.
65 #ifndef THREEFRY2x64_DEFAULT_ROUNDS
66 #define THREEFRY2x64_DEFAULT_ROUNDS 20
69 namespace rocrand_device
// Lookup of the Threefry rotation constants, selected by word type.
// The generic version is deleted (`= delete`). Only word types with an
// explicit specialization (unsigned int and unsigned long long in this file)
// can be used; any other type fails to compile at the call site.
// NOTE(review): the template header line (original ~72) is not visible in
// this excerpt.
73 __forceinline__ __device__ __host__
int threefry_rotation_array(
int index) =
delete;
// Rotation amounts for 32-bit (unsigned int) words.
// `index` is used directly as a subscript into an 8-entry table, so it must
// be in [0, 8). threefry_rounds ensures this by passing `round_idx & 7u`.
// NOTE(review): the values appear to match the Threefry-2x32 rotation
// constants published with Random123; confirm against upstream.
76 __forceinline__ __device__ __host__
int threefry_rotation_array<unsigned int>(
int index)
// `static constexpr`: a single constant table with static storage, shared by
// all calls.
82 static constexpr
int THREEFRY_ROTATION_32_2[8] = {13, 15, 26, 6, 17, 29, 16, 24};
83 return THREEFRY_ROTATION_32_2[index];
// Rotation amounts for 64-bit (unsigned long long) words.
// `index` is used directly as a subscript into an 8-entry table, so it must
// be in [0, 8). threefry_rounds ensures this by passing `round_idx & 7u`.
// NOTE(review): the values appear to match the Threefry-2x64 rotation
// constants published with Random123; confirm against upstream.
87 __forceinline__ __device__ __host__
int threefry_rotation_array<unsigned long long>(
int index)
// `static constexpr`: a single constant table with static storage, shared by
// all calls.
93 static constexpr
int THREEFRY_ROTATION_64_2[8] = {16, 42, 12, 31, 16, 32, 24, 21};
94 return THREEFRY_ROTATION_64_2[index];
// Common implementation of a two-word Threefry counter-based engine.
//   state_value - two-component vector (accessed as .x / .y), used for the
//                 counter and for the generated result block.
//   value       - scalar word type. Must be a type that
//                 threefry_rotation_array is specialized for (unsigned int
//                 or unsigned long long).
//   Nrounds     - number of Threefry rounds per block. At most 32, enforced
//                 by a static_assert in threefry_rounds.
// Each counter value yields a block of two `value` words. Single values are
// handed out from that block one word at a time, tracked by `substate`.
97 template<
typename state_value,
typename value,
unsigned int Nrounds>
98 class threefry_engine2_base
// Complete engine state.
// NOTE(review): the counter/key/result member declarations (original lines
// ~102-105) are not visible in this excerpt.
101 struct threefry_state_2
// Index (0 or 1) of the next word of `result` to return. It is reset to 0
// when it reaches 2, and discard_impl keeps it below 2.
106 unsigned int substate;
// Public aliases for the state struct and its two-word vector type.
108 using state_type = threefry_state_2;
109 using state_vector_type = state_value;
// Skips `offset` single values of the stream, then regenerates the cached
// result block for the (possibly advanced) counter.
111 __forceinline__ __device__ __host__
void discard(
unsigned long long offset)
// Only adjusts substate/counter; discard_impl does not refresh `result`.
113 this->discard_impl(offset);
114 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Regenerates the cached result block from the current counter and key.
// NOTE(review): as visible here, this advances neither the counter nor
// `substate`, so it does not skip a value. If it is meant to skip one value
// (like discard(1)), it would also need discard_impl(1). Confirm against the
// full body; original lines ~118 and ~120-126 are not visible.
117 __forceinline__ __device__ __host__
void discard()
119 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Jumps ahead by `subsequence` subsequences, then regenerates the cached
// result block. discard_subsequence_impl does the jump by adding to
// counter.y.
127 __forceinline__ __device__ __host__
void discard_subsequence(
unsigned long long subsequence)
129 this->discard_subsequence_impl(subsequence);
130 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Returns the next value of the stream.
// NOTE(review): the body (original lines ~134-137) is not visible in this
// excerpt. It presumably forwards to the single-value generator that follows.
133 __forceinline__ __device__ __host__ value operator()()
// Single-value generator: returns word `substate` of the cached result block.
// Once both words of the block have been used (substate reaches 2), it
// resets substate, advances the counter by one block and regenerates
// `result`.
// NOTE(review): several lines are not visible in this excerpt: the function's
// name/signature line (original ~139), the #else (~143), the substate
// increment (presumably ~145-146) and the return statement.
138 __forceinline__ __device__ __host__
141 #if defined(__HIP_PLATFORM_AMD__)
// On AMD, the HIP vector result is indexed through ROCRAND_HIPVEC_ACCESS
// (defined elsewhere).
142 value ret = ROCRAND_HIPVEC_ACCESS(m_state.result)[m_state.substate];
// Otherwise, the components are treated as contiguous starting at `.x`.
144 value ret = (&m_state.result.x)[m_state.substate];
// Both words of the block consumed: move to the next counter and regenerate.
147 if(m_state.substate == 2)
149 m_state.substate = 0;
150 m_state.counter = this->bump_counter(m_state.counter);
151 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Returns two consecutive values at once.
// It advances the counter by one block and regenerates `result`. The returned
// pair is built by interleave() from the old and new blocks, starting at the
// current `substate`. In the visible lines, substate is left unchanged.
156 __forceinline__ __device__ __host__ state_value next2()
// Keep the current block: when substate == 1, its second word is still
// unused.
158 state_value ret = m_state.result;
159 m_state.counter = this->bump_counter(m_state.counter);
160 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
162 return this->interleave(ret, m_state.result);
// Applies Nrounds Threefry rounds to `counter` and returns the resulting
// block.
// The key schedule `ks` holds the key words plus the Skein parity word in
// ks[2]. Each round rotates X.y by the word-type-specific constant for
// round_idx % 8. After every fourth round (round_idx % 4 == 3), a key
// injection adds the rotated key schedule and the injection count.
// NOTE(review): the key parameter, the setup of ks[0..1] and X, and the
// mixing step (original lines ~167-191) are only partially visible here.
166 __forceinline__ __device__ __host__
static state_value threefry_rounds(state_value counter,
172 static_assert(Nrounds <= 32,
"32 or less only supported in threefry rounds");
// Skein key-schedule parity word (skein_ks_parity is defined elsewhere).
174 ks[2] = skein_ks_parity<value>();
189 for(
unsigned int round_idx = 0; round_idx < Nrounds; round_idx++)
// The rotation constants repeat with period 8.
192 X.y = rotl<value>(X.y, threefry_rotation_array<value>(round_idx & 7u));
// Key injection after every 4 rounds.
195 if((round_idx & 3u) == 3)
197 unsigned int inject_idx = round_idx / 4;
// Step through the 3-entry key schedule, and add the 1-based injection count
// to the second word.
199 X.x += ks[(1 + inject_idx) % 3];
200 X.y += ks[(2 + inject_idx) % 3];
201 X.y += 1 + inject_idx;
// Advances the stream position by `offset` single values. It does not
// regenerate `result`; discard(offset) does that afterwards.
// Each counter value yields two words. The low bit of `offset` therefore
// moves `substate`, and offset / 2 moves the counter. If substate reaches 2,
// it is folded back into [0, 2) and the counter is advanced one extra block.
210 __forceinline__ __device__ __host__
void discard_impl(
unsigned long long offset)
213 m_state.substate += offset & 1;
214 unsigned long long counter_offset = offset / 2;
215 counter_offset += m_state.substate < 2 ? 0 : 1;
// Adding -2 to the unsigned substate wraps around (unsigned arithmetic),
// which subtracts 2.
216 m_state.substate += m_state.substate < 2 ? 0 : -2;
218 this->discard_state(counter_offset);
// Skips `subsequence` subsequences by adding directly to the high counter
// word (counter.y). counter.x, substate and result are left untouched; the
// caller regenerates result.
// NOTE(review): if counter.y is a 32-bit `value`, this addition effectively
// reduces the unsigned long long argument modulo 2^32. Confirm that is
// intended.
222 __forceinline__ __device__ __host__
void
223 discard_subsequence_impl(
unsigned long long subsequence)
225 m_state.counter.y += subsequence;
// Advances the two-word counter by `offset` blocks, as a two-word addition
// with carry from counter.x into counter.y.
// split_ull (defined elsewhere) presumably splits `offset` into its low/high
// parts `lo`/`hi`.
// NOTE(review): the declarations of lo/hi (original line ~232) are not
// visible here.
230 __forceinline__ __device__ __host__
void discard_state(
unsigned long long offset)
233 ::rocrand_device::detail::split_ull(lo, hi, offset);
// A carry out of counter.x shows up as unsigned wrap-around: the new value is
// smaller than the old one.
235 value old_counter = m_state.counter.x;
236 m_state.counter.x += lo;
237 m_state.counter.y += hi + (m_state.counter.x < old_counter ? 1 : 0);
// Returns `counter` advanced by one block; the argument is taken by value.
// NOTE(review): only part of the body is visible. Presumably counter.x is
// incremented just before the line below, and `add` is then added to
// counter.y.
240 __forceinline__ __device__ __host__
static state_value bump_counter(state_value counter)
// `add` is 1 exactly when counter.x is 0 (i.e. after it wrapped), which
// signals a carry into counter.y.
243 value add = counter.x == 0 ? 1 : 0;
// Builds two consecutive values from the previous block `prev` and the
// following block `next`, starting at word `substate` of `prev`:
//   substate == 1 -> {prev.y, next.x}
// Falling out of the switch is marked __builtin_unreachable(), so callers
// must keep substate to the handled cases.
// NOTE(review): the substate == 0 case (original lines ~252-253), presumably
// returning `prev`, is not visible in this excerpt.
248 __forceinline__ __device__ __host__ state_value interleave(
const state_value prev,
249 const state_value next)
const
251 switch(m_state.substate)
254 case 1:
return state_value{prev.y, next.x};
256 __builtin_unreachable();
// The engine's complete state: counter, key, cached result block and
// substate.
260 threefry_state_2 m_state;