diff --git a/integration_tests/test_random.py b/integration_tests/test_random.py index db60a16d37..d20f6286e0 100644 --- a/integration_tests/test_random.py +++ b/integration_tests/test_random.py @@ -52,6 +52,33 @@ def test_weibullvariate(): r = random.weibullvariate(-5.6, 1.2) print(r) +def test_seed(): + random.seed() + t6: f64 = random.random() + random.seed(123) + t1: f64 + t1 = random.random() + random.seed(321) + t2: f64 + t2 = random.random() + random.seed(123) + t3: f64 + t3 = random.random() + random.seed(0) + t4: f64 + t4 = random.random() + random.seed(0) + t5: f64 + t5 = random.random() + random.seed() + t7: f64 = random.random() + assert t1 != t2 + assert t1 == t3 + assert t1 != t4 + assert t1 != t5 + assert t4 == t5 + assert t6 != t7 + def check(): test_random() test_randrange() @@ -60,5 +87,6 @@ def check(): test_paretovariate() test_expovariate() test_weibullvariate() + test_seed() check() diff --git a/src/libasr/runtime/lfortran_intrinsics.c b/src/libasr/runtime/lfortran_intrinsics.c index 6548592d5d..f01cb94550 100644 --- a/src/libasr/runtime/lfortran_intrinsics.c +++ b/src/libasr/runtime/lfortran_intrinsics.c @@ -96,6 +96,16 @@ LFORTRAN_API void _lfortran_random_number(int n, double *v) } } +LFORTRAN_API void _lfortran_init_random_seed(unsigned seed) +{ + srand(seed); +} + +LFORTRAN_API void _lfortran_init_random_clock() +{ + srand((unsigned int)clock()); +} + LFORTRAN_API double _lfortran_random() { return (rand() / (double) RAND_MAX); diff --git a/src/libasr/runtime/lfortran_intrinsics.h b/src/libasr/runtime/lfortran_intrinsics.h index ab418523df..3ea56e4991 100644 --- a/src/libasr/runtime/lfortran_intrinsics.h +++ b/src/libasr/runtime/lfortran_intrinsics.h @@ -67,6 +67,8 @@ typedef double _Complex double_complex_t; LFORTRAN_API double _lfortran_sum(int n, double *v); LFORTRAN_API void _lfortran_random_number(int n, double *v); +LFORTRAN_API void _lfortran_init_random_clock(); +LFORTRAN_API void _lfortran_init_random_seed(unsigned seed); LFORTRAN_API double _lfortran_random(); LFORTRAN_API int _lfortran_randrange(int lower, int upper); LFORTRAN_API int _lfortran_random_int(int lower, int upper); diff --git a/src/runtime/random.py b/src/runtime/random.py index 280f5552db..3e7367dc4e 100644 --- a/src/runtime/random.py +++ b/src/runtime/random.py @@ -32,6 +32,30 @@ def random() -> f64: def _lfortran_random() -> f64: pass +@overload +def seed() -> None: + """ + Initializes the random number generator. + """ + _lfortran_init_random_clock() + return + +@overload +def seed(seed: i32) -> None: + """ + Initializes the random number generator. + """ + _lfortran_init_random_seed(seed) + return + +@ccall +def _lfortran_init_random_clock() -> None: + pass + +@ccall +def _lfortran_init_random_seed(seed: i32) -> None: + pass + def randrange(lower: i32, upper: i32) -> i32: """ Return a random integer N such that `lower <= N < upper`.