Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utilities for simplifying assignment of priors #45

Merged
merged 2 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 155 additions & 3 deletions gpax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""

from typing import Union, Dict, Type, List
import inspect
from typing import Union, Dict, Type, List, Callable

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -167,10 +168,161 @@ def preprocess_sparse_image(sparse_image):
return gp_input, targets, full_indices


def normal_prior(param_name, loc=0, scale=1):
def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0):
"""
Samples a value from a normal distribution with the specified mean (loc) and standard deviation (scale),
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions
in structured Gaussian processes.
"""
return numpyro.sample(param_name, numpyro.distributions.Normal(loc, scale))
return numpyro.sample(param_name, normal_dist(loc, scale))


def place_halfnormal_prior(param_name: str, scale: float = 1.0):
"""
Samples a value from a half-normal distribution with the specified standard deviation (scale),
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions
in structured Gaussian processes.
"""
return numpyro.sample(param_name, halfnormal_dist(scale))


def place_uniform_prior(param_name: str,
low: float = None,
high: float = None,
X: jnp.ndarray = None):
"""
Samples a value from a uniform distribution with the specified low and high values,
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions
in structured Gaussian processes.
"""
d = uniform_dist(low, high, X)
return numpyro.sample(param_name, d)


def place_gamma_prior(param_name: str,
c: float = None,
r: float = None,
X: jnp.ndarray = None):
"""
Samples a value from a uniform distribution with the specified concentration (c) and rate (r) values,
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions
in structured Gaussian processes.
"""
d = gamma_dist(c, r, X)
return numpyro.sample(param_name, d)


def normal_dist(loc: float = None, scale: float = None
) -> numpyro.distributions.Distribution:
"""
Generate a Normal distribution based on provided center (loc) and standard deviation (scale) parameters.
I neithere are provided, uses 0 and 1 by default.
"""
loc = loc if loc is not None else 0.0
scale = scale if scale is not None else 1.0
return numpyro.distributions.Normal(loc, scale)


def halfnormal_dist(scale: float = None) -> numpyro.distributions.Distribution:
"""
Generate a half-normal distribution based on provided standard deviation (scale).
If none is provided, uses 1.0 by default.
"""
scale = scale if scale is not None else 1.0
return numpyro.distributions.HalfNormal(scale)


def gamma_dist(c: float = None,
r: float = None,
input_vec: jnp.ndarray = None
) -> numpyro.distributions.Distribution:
"""
Generate a Gamma distribution based on provided shape (c) and rate (r) parameters. If the shape (c) is not provided,
it attempts to infer it using the range of the input vector divided by 2. The rate parameter defaults to 1.0 if not provided.
"""
if c is None:
if input_vec is not None:
c = (input_vec.max() - input_vec.min()) / 2
else:
raise ValueError("Provide either c or an input array")
if r is None:
r = 1.0
return numpyro.distributions.Gamma(c, r)


def uniform_dist(low: float = None,
high: float = None,
input_vec: jnp.ndarray = None
) -> numpyro.distributions.Distribution:
"""
Generate a Uniform distribution based on provided low and high bounds. If one of the bounds is not provided,
it attempts to infer the missing bound(s) using the minimum or maximum value from the input vector.
"""
if (low is None or high is None) and input_vec is None:
raise ValueError(
"If 'low' or 'high' is not provided, an input array must be provided.")
low = low if low is not None else input_vec.min()
high = high if high is not None else input_vec.max()

return numpyro.distributions.Uniform(low, high)


def set_fn(func: Callable) -> Callable:
"""
Transforms the given deterministic function to use a params dictionary
for its parameters, excluding the first one (assumed to be the dependent variable).

Args:
- func (Callable): The deterministic function to be transformed.

Returns:
- Callable: The transformed function where parameters are accessed
from a `params` dictionary.
"""
# Extract parameter names excluding the first one (assumed to be the dependent variable)
params_names = list(inspect.signature(func).parameters.keys())[1:]

# Create the transformed function definition
transformed_code = f"def {func.__name__}(x, params):\n"

# Retrieve the source code of the function and indent it to be a valid function body
source = inspect.getsource(func).split("\n", 1)[1]
source = " " + source.replace("\n", "\n ")

# Replace each parameter name with its dictionary lookup
for name in params_names:
source = source.replace(f" {name}", f' params["{name}"]')

# Combine to get the full source
transformed_code += source

# Define the transformed function in the local namespace
local_namespace = {}
exec(transformed_code, globals(), local_namespace)

# Return the transformed function
return local_namespace[func.__name__]


def auto_normal_priors(func: Callable, loc: float = 0.0, scale: float = 1.0) -> Callable:
"""
Generates a function that, when invoked, samples from normal distributions
for each parameter of the given deterministic function, except the first one.

Args:
- func (Callable): The deterministic function for which to set normal priors.
- loc (float, optional): Mean of the normal distribution. Defaults to 0.0.
- scale (float, optional): Standard deviation of the normal distribution. Defaults to 1.0.

Returns:
- Callable: A function that, when invoked, returns a dictionary of sampled values
from normal distributions for each parameter of the original function.
"""
# Get the names of the parameters of the function excluding the first one (dependent variable)
params_names = list(inspect.signature(func).parameters.keys())[1:]

def sample_priors() -> Dict[str, Union[float, Type[Callable]]]:
# Return a dictionary with normal priors for each parameter
return {name: place_normal_prior(name, loc, scale) for name in params_names}

return sample_priors
132 changes: 128 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import pytest
import numpy as onp
import jax.numpy as jnp
import jax.random as jra
Expand All @@ -7,9 +8,14 @@

sys.path.insert(0, "../gpax/")

from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys, normal_prior
from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys
from gpax.utils import place_normal_prior, place_halfnormal_prior, place_uniform_prior, place_gamma_prior, gamma_dist, uniform_dist, normal_dist, halfnormal_dist
from gpax.utils import set_fn, auto_normal_priors


def sample_function(x, a, b):
return a + b * x

def test_sparse_img_processing():
img = onp.random.randn(16, 16)
# Generate random indices
Expand Down Expand Up @@ -107,17 +113,135 @@ def test_get_keys_different_seeds():
assert_(not onp.array_equal(key2, key2a))


def test_normal_prior():
@pytest.mark.parametrize("prior", [place_normal_prior, place_halfnormal_prior])
def test_normal_prior(prior):
with numpyro.handlers.seed(rng_seed=1):
sample = prior("a")
assert_(isinstance(sample, jnp.ndarray))


def test_uniform_prior():
with numpyro.handlers.seed(rng_seed=1):
sample = place_uniform_prior("a", 0, 1)
assert_(isinstance(sample, jnp.ndarray))


def test_gamma_prior():
with numpyro.handlers.seed(rng_seed=1):
sample = normal_prior("a")
sample = place_gamma_prior("a", 2, 2)
assert_(isinstance(sample, jnp.ndarray))


def test_normal_prior_params():
with numpyro.handlers.seed(rng_seed=1):
with numpyro.handlers.trace() as tr:
normal_prior("a", loc=0.5, scale=0.1)
place_normal_prior("a", loc=0.5, scale=0.1)
site = tr["a"]
assert_(isinstance(site['fn'], numpyro.distributions.Normal))
assert_equal(site['fn'].loc, 0.5)
assert_equal(site['fn'].scale, 0.1)


def test_halfnormal_prior_params():
with numpyro.handlers.seed(rng_seed=1):
with numpyro.handlers.trace() as tr:
place_halfnormal_prior("a", 0.1)
site = tr["a"]
assert_(isinstance(site['fn'], numpyro.distributions.HalfNormal))
assert_equal(site['fn'].scale, 0.1)


def test_uniform_prior_params():
with numpyro.handlers.seed(rng_seed=1):
with numpyro.handlers.trace() as tr:
place_uniform_prior("a", low=0.5, high=1.0)
site = tr["a"]
assert_(isinstance(site['fn'], numpyro.distributions.Uniform))
assert_equal(site['fn'].low, 0.5)
assert_equal(site['fn'].high, 1.0)


def test_gamma_prior_params():
with numpyro.handlers.seed(rng_seed=1):
with numpyro.handlers.trace() as tr:
place_gamma_prior("a", c=2.0, r=1.0)
site = tr["a"]
assert_(isinstance(site['fn'], numpyro.distributions.Gamma))
assert_equal(site['fn'].concentration, 2.0)
assert_equal(site['fn'].rate, 1.0)


def test_get_uniform_dist():
uniform_dist_ = uniform_dist(low=1.0, high=5.0)
assert isinstance(uniform_dist_, numpyro.distributions.Uniform)
assert uniform_dist_.low == 1.0
assert uniform_dist_.high == 5.0


def test_get_uniform_dist_infer_params():
uniform_dist_ = uniform_dist(input_vec=jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]))
assert uniform_dist_.low == 1.0
assert uniform_dist_.high == 5.0


def test_get_gamma_dist():
gamma_dist_ = gamma_dist(c=2.0, r=1.0)
assert isinstance(gamma_dist_, numpyro.distributions.Gamma)
assert gamma_dist_.concentration == 2.0
assert gamma_dist_.rate == 1.0


def test_get_normal_dist():
normal_dist_ = normal_dist(loc=2.0, scale=3.0)
assert isinstance(normal_dist_, numpyro.distributions.Normal)
assert normal_dist_.loc == 2.0
assert normal_dist_.scale == 3.0


def test_get_halfnormal_dist():
halfnormal_dist_ = halfnormal_dist(scale=1.5)
assert isinstance(halfnormal_dist_, numpyro.distributions.HalfNormal)
assert halfnormal_dist_.scale == 1.5


def test_get_gamma_dist_infer_param():
gamma_dist_ = gamma_dist(input_vec=jnp.linspace(0, 10, 20))
assert isinstance(gamma_dist_, numpyro.distributions.Gamma)
assert gamma_dist_.concentration == 5.0
assert gamma_dist_.rate == 1.0


def test_get_uniform_dist_error():
with pytest.raises(ValueError):
uniform_dist(low=1.0) # Only low provided without input_vec
with pytest.raises(ValueError):
uniform_dist(high=5.0) # Only high provided without input_vec
with pytest.raises(ValueError):
uniform_dist() # Neither low nor high, and no input_vec


def test_get_gamma_dist_error():
with pytest.raises(ValueError):
uniform_dist() # Neither concentration, nor input_vec


def test_set_fn():
transformed_fn = set_fn(sample_function)
result = transformed_fn(2, {"a": 1, "b": 3})
assert result == 7 # Expected output: 1 + 3*2 = 7


def test_auto_normal_priors():
prior_fn = auto_normal_priors(sample_function, loc=2.0, scale=1.0)
with numpyro.handlers.seed(rng_seed=1):
with numpyro.handlers.trace() as tr:
prior_fn()
site1 = tr["a"]
assert_(isinstance(site1['fn'], numpyro.distributions.Normal))
assert_equal(site1['fn'].loc, 2.0)
assert_equal(site1['fn'].scale, 1.0)
site2 = tr["b"]
assert_(isinstance(site2['fn'], numpyro.distributions.Normal))
assert_equal(site2['fn'].loc, 2.0)
assert_equal(site2['fn'].scale, 1.0)

Loading