Skip to content

Commit

Permalink
Add utilities for assigning priors
Browse files Browse the repository at this point in the history
Covers normal, half-normal, uniform, and gamma distributions
  • Loading branch information
ziatdinovmax committed Oct 8, 2023
1 parent 4049818 commit cab5bd1
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 6 deletions.
94 changes: 92 additions & 2 deletions gpax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,100 @@ def preprocess_sparse_image(sparse_image):
return gp_input, targets, full_indices


def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0):
    """
    Place a normal prior on a named random variable of the probabilistic model.

    The distribution is built by ``normal_dist(loc, scale)`` and registered
    with ``numpyro.sample`` under ``param_name``. Can be useful for defining
    prior mean functions in structured Gaussian processes.

    Args:
        param_name: name of the numpyro sample site
        loc: mean of the normal distribution
        scale: standard deviation of the normal distribution

    Returns:
        Whatever ``numpyro.sample`` returns for this site.
    """
    return numpyro.sample(param_name, normal_dist(loc, scale))


def normal_prior(param_name, loc=0, scale=1):
    """
    Earlier name of :func:`place_normal_prior`, kept so that existing
    imports of ``normal_prior`` continue to work.

    Samples ``param_name`` from ``numpyro.distributions.Normal(loc, scale)``.
    """
    return numpyro.sample(param_name, numpyro.distributions.Normal(loc, scale))


def place_halfnormal_prior(param_name: str, scale: float = 1.0):
    """
    Place a half-normal prior on a named random variable of the
    probabilistic model.

    Args:
        param_name: name of the numpyro sample site
        scale: scale passed to ``halfnormal_dist``

    Returns:
        Whatever ``numpyro.sample`` returns for this site.
    """
    prior = halfnormal_dist(scale)
    return numpyro.sample(param_name, prior)


def place_uniform_prior(param_name: str,
                        low: float = None,
                        high: float = None,
                        X: jnp.ndarray = None):
    """
    Place a uniform prior on a named random variable of the probabilistic
    model.

    The distribution is produced by ``uniform_dist(low, high, X)``; see that
    function for how missing bounds are handled.

    Args:
        param_name: name of the numpyro sample site
        low: lower bound (optional)
        high: upper bound (optional)
        X: input array forwarded to ``uniform_dist`` (optional)

    Returns:
        Whatever ``numpyro.sample`` returns for this site.
    """
    return numpyro.sample(param_name, uniform_dist(low, high, X))


def place_gamma_prior(param_name: str,
                      c: float = None,
                      r: float = None,
                      X: jnp.ndarray = None):
    """
    Samples a value from a gamma distribution with the specified
    concentration (c) and rate (r) values, and assigns it to a named random
    variable in the probabilistic model. Can be useful for defining prior
    mean functions in structured Gaussian processes.

    The distribution is produced by ``gamma_dist(c, r, X)``; see that
    function for how missing parameters are handled.

    Args:
        param_name: name of the numpyro sample site
        c: concentration of the gamma distribution (optional)
        r: rate of the gamma distribution (optional)
        X: input array forwarded to ``gamma_dist`` (optional)

    Returns:
        Whatever ``numpyro.sample`` returns for this site.
    """
    d = gamma_dist(c, r, X)
    return numpyro.sample(param_name, d)


def normal_dist(loc: float = None, scale: float = None
                ) -> numpyro.distributions.Distribution:
    """
    Build a Normal distribution from a center (loc) and a standard
    deviation (scale). A parameter left as None falls back to its default:
    0.0 for loc and 1.0 for scale.
    """
    mean = 0.0 if loc is None else loc
    std = 1.0 if scale is None else scale
    return numpyro.distributions.Normal(mean, std)


def halfnormal_dist(scale: float = None) -> numpyro.distributions.Distribution:
    """
    Build a HalfNormal distribution with the given scale.
    When scale is None, 1.0 is used.
    """
    if scale is None:
        scale = 1.0
    return numpyro.distributions.HalfNormal(scale)


def gamma_dist(c: float = None,
               r: float = None,
               input_vec: jnp.ndarray = None
               ) -> numpyro.distributions.Distribution:
    """
    Build a Gamma distribution from a concentration (c) and a rate (r).

    When c is None it is set to half the range of input_vec, i.e.
    (input_vec.max() - input_vec.min()) / 2. When r is None it is set to 1.0.

    Raises:
        ValueError: if both c and input_vec are None.
    """
    # Guard first: without c there must be an array to derive it from.
    if c is None and input_vec is None:
        raise ValueError("Provide either c or an input array")
    concentration = c if c is not None else (input_vec.max() - input_vec.min()) / 2
    rate = 1.0 if r is None else r
    return numpyro.distributions.Gamma(concentration, rate)


def uniform_dist(low: float = None,
                 high: float = None,
                 input_vec: jnp.ndarray = None
                 ) -> numpyro.distributions.Distribution:
    """
    Build a Uniform distribution from low and high bounds.

    A bound left as None is taken from input_vec: its minimum for low,
    its maximum for high.

    Raises:
        ValueError: if a bound is missing and input_vec is None.
    """
    missing_bound = low is None or high is None
    if input_vec is None and missing_bound:
        raise ValueError(
            "If 'low' or 'high' is not provided, an input array must be provided.")
    if low is None:
        low = input_vec.min()
    if high is None:
        high = input_vec.max()
    return numpyro.distributions.Uniform(low, high)
109 changes: 105 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import pytest
import numpy as onp
import jax.numpy as jnp
import jax.random as jra
Expand All @@ -7,7 +8,8 @@

sys.path.insert(0, "../gpax/")

from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys, normal_prior
from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys
from gpax.utils import place_normal_prior, place_halfnormal_prior, place_uniform_prior, place_gamma_prior, gamma_dist, uniform_dist, normal_dist, halfnormal_dist


def test_sparse_img_processing():
Expand Down Expand Up @@ -107,17 +109,116 @@ def test_get_keys_different_seeds():
assert_(not onp.array_equal(key2, key2a))


@pytest.mark.parametrize("prior", [place_normal_prior, place_halfnormal_prior])
def test_normal_prior(prior):
    """Each prior helper, called with default parameters, yields a jax array."""
    with numpyro.handlers.seed(rng_seed=1):
        sample = prior("a")
    assert_(isinstance(sample, jnp.ndarray))


def test_uniform_prior():
    """place_uniform_prior with explicit bounds yields a jax array."""
    with numpyro.handlers.seed(rng_seed=1):
        drawn = place_uniform_prior("a", 0, 1)
    assert_(isinstance(drawn, jnp.ndarray))


def test_gamma_prior():
    """place_gamma_prior with explicit c and r yields a jax array."""
    with numpyro.handlers.seed(rng_seed=1):
        drawn = place_gamma_prior("a", 2, 2)
    assert_(isinstance(drawn, jnp.ndarray))


def test_normal_prior_params():
    """The traced site for place_normal_prior carries the requested Normal."""
    with numpyro.handlers.seed(rng_seed=1):
        with numpyro.handlers.trace() as tr:
            # Only one sample statement may use the site name "a" per trace.
            place_normal_prior("a", loc=0.5, scale=0.1)
    site = tr["a"]
    assert_(isinstance(site['fn'], numpyro.distributions.Normal))
    assert_equal(site['fn'].loc, 0.5)
    assert_equal(site['fn'].scale, 0.1)


def test_halfnormal_prior_params():
    """The traced site for place_halfnormal_prior carries the requested HalfNormal."""
    with numpyro.handlers.seed(rng_seed=1), numpyro.handlers.trace() as tr:
        place_halfnormal_prior("a", 0.1)
    dist_fn = tr["a"]['fn']
    assert_(isinstance(dist_fn, numpyro.distributions.HalfNormal))
    assert_equal(dist_fn.scale, 0.1)


def test_uniform_prior_params():
    """The traced site for place_uniform_prior carries the requested Uniform."""
    with numpyro.handlers.seed(rng_seed=1), numpyro.handlers.trace() as tr:
        place_uniform_prior("a", low=0.5, high=1.0)
    dist_fn = tr["a"]['fn']
    assert_(isinstance(dist_fn, numpyro.distributions.Uniform))
    assert_equal(dist_fn.low, 0.5)
    assert_equal(dist_fn.high, 1.0)


def test_gamma_prior_params():
    """The traced site for place_gamma_prior carries the requested Gamma."""
    with numpyro.handlers.seed(rng_seed=1), numpyro.handlers.trace() as tr:
        place_gamma_prior("a", c=2.0, r=1.0)
    dist_fn = tr["a"]['fn']
    assert_(isinstance(dist_fn, numpyro.distributions.Gamma))
    assert_equal(dist_fn.concentration, 2.0)
    assert_equal(dist_fn.rate, 1.0)


def test_get_uniform_dist():
    """uniform_dist with explicit bounds returns a Uniform with those bounds."""
    built = uniform_dist(low=1.0, high=5.0)
    assert isinstance(built, numpyro.distributions.Uniform)
    assert built.low == 1.0
    assert built.high == 5.0


def test_get_uniform_dist_infer_params():
    """uniform_dist takes both bounds from the min and max of the input vector."""
    data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
    built = uniform_dist(input_vec=data)
    assert built.low == 1.0
    assert built.high == 5.0


def test_get_gamma_dist():
    """gamma_dist with explicit c and r returns a Gamma with those parameters."""
    built = gamma_dist(c=2.0, r=1.0)
    assert isinstance(built, numpyro.distributions.Gamma)
    assert built.concentration == 2.0
    assert built.rate == 1.0


def test_get_normal_dist():
    """normal_dist with explicit loc and scale returns a Normal with those parameters."""
    built = normal_dist(loc=2.0, scale=3.0)
    assert isinstance(built, numpyro.distributions.Normal)
    assert built.loc == 2.0
    assert built.scale == 3.0


def test_get_halfnormal_dist():
    """halfnormal_dist with an explicit scale returns a HalfNormal with that scale."""
    built = halfnormal_dist(scale=1.5)
    assert isinstance(built, numpyro.distributions.HalfNormal)
    assert built.scale == 1.5


def test_get_gamma_dist_infer_param():
    """gamma_dist derives c as half the input range and defaults r to 1.0."""
    data = jnp.linspace(0, 10, 20)
    built = gamma_dist(input_vec=data)
    assert isinstance(built, numpyro.distributions.Gamma)
    assert built.concentration == 5.0
    assert built.rate == 1.0


def test_get_uniform_dist_error():
    """uniform_dist raises ValueError when a bound is missing and no input_vec is given."""
    incomplete_calls = (
        {"low": 1.0},   # Only low provided without input_vec
        {"high": 5.0},  # Only high provided without input_vec
        {},             # Neither low nor high, and no input_vec
    )
    for kwargs in incomplete_calls:
        with pytest.raises(ValueError):
            uniform_dist(**kwargs)


def test_get_gamma_dist_error():
    """gamma_dist raises ValueError when neither c nor input_vec is given."""
    with pytest.raises(ValueError):
        gamma_dist()  # Neither concentration, nor input_vec



0 comments on commit cab5bd1

Please sign in to comment.