Skip to content

Commit

Permalink
Add the option to specify a seed for the Random sampling method (#696)
Browse files Browse the repository at this point in the history
* Add the option to specify a seed for the Random sampling method

* Behaviour is controlled by the numpy version used

* Create unit tests

* Update doc with new option and add seed to test function to avoid to generete different image every time the doc is compiled

* Change authors order

---------

Co-authored-by: Enrico Stragiotti <enrico.stragiotti@onera.fr>
  • Loading branch information
enricostragiotti and Enrico Stragiotti authored Dec 17, 2024
1 parent c65e721 commit 61bf8c1
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 4 deletions.
1 change: 1 addition & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ SMT has been developed thanks to contributions from:
* Andres Lopez Lopera
* Antoine Averland
* Emile Roux
* Enrico Stragiotti
* Ewout ter Hoeven
* Florent Vergnes
* Frederick Zahle
Expand Down
2 changes: 1 addition & 1 deletion doc/_src_docs/sampling_methods.rst

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion doc/_src_docs/sampling_methods/random.rst

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified doc/_src_docs/sampling_methods/random_Test_run_random.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified doc/_src_docs/sampling_methods_Test_run_random.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
48 changes: 47 additions & 1 deletion smt/sampling_methods/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,53 @@
Random sampling.
"""

import warnings

import numpy as np

from smt.sampling_methods.sampling_method import ScaledSamplingMethod

# Check NumPy version
numpy_version = tuple(
map(int, np.__version__.split(".")[:2])
) # Extract major and minor version


class Random(ScaledSamplingMethod):
def _initialize(self, **kwargs):
self.options.declare(
"random_state",
types=(type(None), int, np.random.RandomState, np.random.Generator),
desc="Numpy RandomState or Generator object or seed number which controls random draws",
)

# Update options values passed by the user here to get 'random_state' option
self.options.update(kwargs)

# RandomState and Generator are and have to be initialized once at constructor time,
# not in _compute to avoid yielding the same dataset again and again
if numpy_version < (2, 0): # Version is below 2.0.0
if isinstance(self.options["random_state"], np.random.RandomState):
self.random_state = self.options["random_state"]
elif isinstance(self.options["random_state"], np.random.Generator):
self.random_state = np.random.RandomState()
warnings.warn(
"numpy.random.Generator initialization of random_state is not implemented for numpy "
"versions < 2.0.0. Using the default np.random.RandomState() as random_state. "
"Please consider upgrading to numpy version > 2.0.0, or use the legacy numpy.random.RandomState "
"class in the future.",
FutureWarning,
)
elif isinstance(self.options["random_state"], int):
self.random_state = np.random.RandomState(self.options["random_state"])
else:
self.random_state = np.random.RandomState()
else:
# Construct a new Generator with the default BitGenerator (PCG64).
# If passed a Generator, it will be returned unaltered. When passed a legacy
# RandomState instance it will be coerced to a Generator.
self.random_state = np.random.default_rng(seed=self.options["random_state"])

def _compute(self, nt):
"""
Implemented by sampling methods to compute the requested number of sampling points.
Expand All @@ -30,4 +71,9 @@ def _compute(self, nt):
"""
xlimits = self.options["xlimits"]
nx = xlimits.shape[0]
return np.random.rand(nt, nx)
if numpy_version < (2, 0): # Version is below 2.0.0
return self.random_state.rand(nt, nx)
else:
# Create a Generator object with a specified seed (numpy.random_state.rand(nt, nx)
# is being deprecated)
return self.random_state.random((nt, nx))
73 changes: 73 additions & 0 deletions smt/sampling_methods/tests/test_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import unittest
from unittest.mock import patch

import numpy as np
import numpy.testing as npt

from smt.sampling_methods import Random


class TestRandomSamplingMethod(unittest.TestCase):
def setUp(self):
self.xlimits = np.array([[0.0, 1.0], [0.0, 1.0]]) # 2D unit hypercube

def test_random_state_initialization_legacy(self):
# Test random state initialization for numpy < 2.0.0
with patch("smt.sampling_methods.random.numpy_version", new=(1, 21)):
sampler = Random(xlimits=self.xlimits, random_state=12)
self.assertIsInstance(sampler.random_state, np.random.RandomState)

def test_random_state_initialization_new(self):
# Test random state initialization for numpy >= 2.0.0
with patch("smt.sampling_methods.random.numpy_version", new=(2, 0)):
sampler = Random(xlimits=self.xlimits, random_state=12)
self.assertIsInstance(sampler.random_state, np.random.Generator)

def test_random_state_warning_for_generator_legacy(self):
# Test that a warning is issued when using Generator with numpy < 2.0.0
with (
patch("smt.sampling_methods.random.numpy_version", new=(1, 21)),
self.assertWarns(FutureWarning),
):
sampler = Random(xlimits=self.xlimits, random_state=np.random.default_rng())
self.assertIsInstance(sampler.random_state, np.random.RandomState)

def test_compute_legacy(self):
# Test _compute method for numpy < 2.0.0
with patch("smt.sampling_methods.random.numpy_version", new=(1, 26)):
sampler = Random(xlimits=self.xlimits, random_state=12)
points = sampler(4)
self.assertEqual(points.shape, (4, 2))
self.assertTrue(np.all(points >= 0) and np.all(points <= 1))
# Check almost equality with known seed-generated data (example)
expected_points = np.array(
[
[0.154163, 0.74005],
[0.263315, 0.533739],
[0.014575, 0.918747],
[0.900715, 0.033421],
]
)
npt.assert_allclose(points, expected_points, rtol=1e-4)

def test_compute_new(self):
# Test _compute method for numpy >= 2.0.0
with patch("smt.sampling_methods.random.numpy_version", new=(2, 2)):
sampler = Random(xlimits=self.xlimits, random_state=12)
points = sampler(4)
self.assertEqual(points.shape, (4, 2))
self.assertTrue(np.all(points >= 0) and np.all(points <= 1))
# Check almost equality with known seed-generated data (example)
expected_points = np.array(
[
[0.250824, 0.946753],
[0.18932, 0.179291],
[0.349889, 0.230541],
[0.670446, 0.115079],
]
)
npt.assert_allclose(points, expected_points, rtol=1e-4)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def run_random():
from smt.sampling_methods import Random

xlimits = np.array([[0.0, 4.0], [0.0, 3.0]])
sampling = Random(xlimits=xlimits)
sampling = Random(xlimits=xlimits, random_state=12)

num = 50
x = sampling(num)
Expand Down

0 comments on commit 61bf8c1

Please sign in to comment.