Skip to content

Commit

Permalink
Added PPF tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vhaasteren committed Sep 21, 2023
1 parent ef663fa commit 1fa8dd9
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
5 changes: 4 additions & 1 deletion enterprise/signals/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ def NormalPPF(value, mu, sigma):
Handles scalar mu and sigma, compatible vector value/mu/sigma,
vector value/mu and compatible covariance matrix sigma."""

if np.ndim(sigma) == 2:
raise NotImplementedError("PPF not implemented when sigma is 2D")

return sstats.norm.ppf(value, loc=mu, scale=sigma)


Expand All @@ -285,7 +288,7 @@ def Normal(mu=0, sigma=1, size=None):
class Normal(Parameter):
_size = size
_prior = Function(NormalPrior, mu=mu, sigma=sigma)
_ppf = Function(NormaPPF, mu=mu, sigma=sigma)
_ppf = Function(NormalPPF, mu=mu, sigma=sigma)
_sampler = staticmethod(NormalSampler)
_typename = _argrepr("Normal", mu=mu, sigma=sigma)

Expand Down
43 changes: 40 additions & 3 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
import numpy as np
import scipy.stats

from enterprise.signals.parameter import UniformPrior, UniformSampler, Uniform
from enterprise.signals.parameter import NormalPrior, NormalSampler, Normal
from enterprise.signals.parameter import UniformPrior, UniformSampler, Uniform, UniformPPF
from enterprise.signals.parameter import NormalPrior, NormalSampler, Normal, NormalPPF
from enterprise.signals.parameter import TruncNormalPrior, TruncNormalSampler, TruncNormal
from enterprise.signals.parameter import LinearExpPrior, LinearExpSampler
from enterprise.signals.parameter import LinearExpPrior, LinearExpSampler, LinearExpPPF


class TestParameter(unittest.TestCase):
Expand All @@ -35,6 +35,9 @@ def test_uniform(self):
assert p_min < x1 < p_max, msg2
assert type(x1) == float, msg2

msg3 = "Enterprise and scipy PPF do not match"
assert np.allclose(UniformPPF(x, p_min, p_max), scipy.stats.uniform.ppf(x, p_min, p_max - p_min)), msg3

# vector argument
x = np.array([0.5, 0.1])
assert np.allclose(UniformPrior(x, p_min, p_max), scipy.stats.uniform.pdf(x, p_min, p_max - p_min)), msg1
Expand All @@ -43,9 +46,13 @@ def test_uniform(self):
assert np.all((p_min < x1) & (x1 < p_max)), msg2
assert x1.shape == (3,), msg2

# vector argument
assert np.allclose(UniformPPF(x, p_min, p_max), scipy.stats.uniform.ppf(x, p_min, p_max - p_min)), msg3

# vector bounds
p_min, p_max = np.array([0.2, 0.3]), np.array([1.1, 1.2])
assert np.allclose(UniformPrior(x, p_min, p_max), scipy.stats.uniform.pdf(x, p_min, p_max - p_min)), msg1
assert np.allclose(UniformPPF(x, p_min, p_max), scipy.stats.uniform.ppf(x, p_min, p_max - p_min)), msg3

x1 = UniformSampler(p_min, p_max)
assert np.all((p_min < x1) & (x1 < p_max)), msg2
Expand All @@ -68,6 +75,10 @@ def test_linearexp(self):
msg1b = "Scalar sampler out of range"
assert p_min <= x <= p_max, msg1b

msg1c = "Scalar PPF does not match"
x = 0.5
assert np.allclose(LinearExpPPF(x, p_min, p_max), np.log10(10**p_min + x*(10**p_max-10**p_min))), msg1c

# vector argument
x = np.array([0, 1.5, 2.5])
msg2 = "Vector-argument prior does not match"
Expand All @@ -79,6 +90,13 @@ def test_linearexp(self):
msg2b = "Vector-argument sampler out of range"
assert np.all((p_min < x) & (x < p_max)), msg2b

x = np.array([0.5, 0.75])
msg2c = "Vector-argument PPF does not match"
assert np.allclose(
LinearExpPPF(x, p_min, p_max),
np.log10(10**p_min + x * (10**p_max-10**p_min))
), msg2c

# vector bounds
p_min, p_max = np.array([0, 1]), np.array([2, 3])
x = np.array([1, 2])
Expand All @@ -88,6 +106,15 @@ def test_linearexp(self):
np.array([10**1 / (10**2 - 10**0), 10**2 / (10**3 - 10**1)]) * np.log(10),
), msg3

# Vector PPF
x = np.array([0.5, 0.75])
p_min, p_max = np.array([0, 1]), np.array([2, 3])
msg3c = "Vector-argument PPF+bounds does not match"
assert np.allclose(
LinearExpPPF(x, p_min, p_max),
np.log10(10**p_min + x * (10**p_max-10**p_min))
), msg3c

def test_normal(self):
"""Test Normal parameter prior and sampler for various combinations of scalar and vector arguments."""

Expand All @@ -105,6 +132,11 @@ def test_normal(self):
# this should almost never fail
assert -5 < (x1 - mu) / sigma < 5, msg2

msg3 = "Enterprise and scipy PPF do not match"
assert np.allclose(
NormalPPF(x, mu, sigma), scipy.stats.norm.ppf(x, loc=mu, scale=sigma)
), msg3

# vector argument
x = np.array([-0.2, 0.1, 0.5])

Expand All @@ -118,6 +150,11 @@ def test_normal(self):
)
assert x1.shape == x2.shape, msg2

x = np.array([0.1, 0.25, 0.65])
assert np.allclose(
NormalPPF(x, mu, sigma), scipy.stats.norm.ppf(x, loc=mu, scale=sigma)
), msg3

# vector bounds; note the different semantics from `NormalPrior`,
# which returns a vector consistently with `UniformPrior`
mu, sigma = np.array([0.1, 0.15, 0.2]), np.array([2, 1, 2])
Expand Down

0 comments on commit 1fa8dd9

Please sign in to comment.