Skip to content

Commit

Permalink
Fix cong implementation to be properly random and not just cycling. (J…
Browse files Browse the repository at this point in the history
…uliaLang#55509)

This was found by @IanButterworth. It unfortunately has a small
performance regression due to actually using all the rng bits
  • Loading branch information
gbaraldi authored Aug 26, 2024
1 parent 6477530 commit 733b3f5
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/gc-stock.h
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ STATIC_INLINE int gc_is_concurrent_collector_thread(int tid) JL_NOTSAFEPOINT
STATIC_INLINE int gc_random_parallel_collector_thread_id(jl_ptls_t ptls) JL_NOTSAFEPOINT
{
assert(jl_n_markthreads > 0);
int v = gc_first_tid + (int)cong(jl_n_markthreads - 1, &ptls->rngseed);
int v = gc_first_tid + (int)cong(jl_n_markthreads, &ptls->rngseed); // cong is [0, n)
assert(v >= gc_first_tid && v <= gc_last_parallel_collector_thread_id());
return v;
}
Expand Down
34 changes: 25 additions & 9 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1307,20 +1307,36 @@ JL_DLLEXPORT size_t jl_maxrss(void);
// congruential random number generator
// for a small amount of thread-local randomness

STATIC_INLINE uint64_t cong(uint64_t max, uint64_t *seed) JL_NOTSAFEPOINT
//TODO: utilize https://github.com/openssl/openssl/blob/master/crypto/rand/rand_uniform.c#L13-L99
// for better performance, it does however require making users expect a 32bit random number.

STATIC_INLINE uint64_t cong(uint64_t max, uint64_t *seed) JL_NOTSAFEPOINT // Open interval [0, max)
{
if (max == 0)
if (max < 2)
return 0;
uint64_t mask = ~(uint64_t)0;
--max;
mask >>= __builtin_clzll(max|1);
uint64_t x;
int zeros = __builtin_clzll(max);
int bits = CHAR_BIT * sizeof(uint64_t) - zeros;
mask = mask >> zeros;
do {
*seed = 69069 * (*seed) + 362437;
x = *seed & mask;
} while (x > max);
return x;
uint64_t value = 69069 * (*seed) + 362437;
*seed = value;
uint64_t x = value & mask;
if (x < max) {
return x;
}
int bits_left = zeros;
while (bits_left >= bits) {
value >>= bits;
x = value & mask;
if (x < max) {
return x;
}
bits_left -= bits;
}
} while (1);
}

JL_DLLEXPORT uint64_t jl_rand(void) JL_NOTSAFEPOINT;
JL_DLLEXPORT void jl_srand(uint64_t) JL_NOTSAFEPOINT;
JL_DLLEXPORT void jl_init_rand(void);
Expand Down
2 changes: 1 addition & 1 deletion src/scheduler.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ extern int jl_gc_mark_queue_obj_explicit(jl_gc_mark_cache_t *gc_cache,
// parallel task runtime
// ---

JL_DLLEXPORT uint32_t jl_rand_ptls(uint32_t max)
JL_DLLEXPORT uint32_t jl_rand_ptls(uint32_t max) // [0, n)
{
jl_ptls_t ptls = jl_current_task->ptls;
return cong(max, &ptls->rngseed);
Expand Down
2 changes: 1 addition & 1 deletion src/signal-handling.c
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ static void jl_shuffle_int_array_inplace(int *carray, int size, uint64_t *seed)
// The "modern Fisher–Yates shuffle" - O(n) algorithm
// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
for (int i = size; i-- > 1; ) {
size_t j = cong(i, seed);
size_t j = cong(i + 1, seed); // cong is an open interval so we add 1
uint64_t tmp = carray[j];
carray[j] = carray[i];
carray[i] = tmp;
Expand Down

0 comments on commit 733b3f5

Please sign in to comment.