Skip to content

Commit

Permalink
Option to set normal priors automatically
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Oct 8, 2023
1 parent cab5bd1 commit 4edb082
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 1 deletion.
64 changes: 63 additions & 1 deletion gpax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""

from typing import Union, Dict, Type, List
import inspect
from typing import Union, Dict, Type, List, Callable

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -264,3 +265,64 @@ def uniform_dist(low: float = None,
high = high if high is not None else input_vec.max()

return numpyro.distributions.Uniform(low, high)


def set_fn(func: Callable) -> Callable:
"""
Transforms the given deterministic function to use a params dictionary
for its parameters, excluding the first one (assumed to be the dependent variable).
Args:
- func (Callable): The deterministic function to be transformed.
Returns:
- Callable: The transformed function where parameters are accessed
from a `params` dictionary.
"""
# Extract parameter names excluding the first one (assumed to be the dependent variable)
params_names = list(inspect.signature(func).parameters.keys())[1:]

# Create the transformed function definition
transformed_code = f"def {func.__name__}(x, params):\n"

# Retrieve the source code of the function and indent it to be a valid function body
source = inspect.getsource(func).split("\n", 1)[1]
source = " " + source.replace("\n", "\n ")

# Replace each parameter name with its dictionary lookup
for name in params_names:
source = source.replace(f" {name}", f' params["{name}"]')

# Combine to get the full source
transformed_code += source

# Define the transformed function in the local namespace
local_namespace = {}
exec(transformed_code, globals(), local_namespace)

# Return the transformed function
return local_namespace[func.__name__]


def auto_normal_priors(func: Callable, loc: float = 0.0, scale: float = 1.0) -> Callable:
"""
Generates a function that, when invoked, samples from normal distributions
for each parameter of the given deterministic function, except the first one.
Args:
- func (Callable): The deterministic function for which to set normal priors.
- loc (float, optional): Mean of the normal distribution. Defaults to 0.0.
- scale (float, optional): Standard deviation of the normal distribution. Defaults to 1.0.
Returns:
- Callable: A function that, when invoked, returns a dictionary of sampled values
from normal distributions for each parameter of the original function.
"""
# Get the names of the parameters of the function excluding the first one (dependent variable)
params_names = list(inspect.signature(func).parameters.keys())[1:]

def sample_priors() -> Dict[str, Union[float, Type[Callable]]]:
# Return a dictionary with normal priors for each parameter
return {name: place_normal_prior(name, loc, scale) for name in params_names}

return sample_priors
23 changes: 23 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@

from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys
from gpax.utils import place_normal_prior, place_halfnormal_prior, place_uniform_prior, place_gamma_prior, gamma_dist, uniform_dist, normal_dist, halfnormal_dist
from gpax.utils import set_fn, auto_normal_priors


def sample_function(x, a, b):
return a + b * x

def test_sparse_img_processing():
img = onp.random.randn(16, 16)
# Generate random indices
Expand Down Expand Up @@ -221,4 +225,23 @@ def test_get_gamma_dist_error():
uniform_dist() # Neither concentration, nor input_vec


def test_set_fn():
transformed_fn = set_fn(sample_function)
result = transformed_fn(2, {"a": 1, "b": 3})
assert result == 7 # Expected output: 1 + 3*2 = 7


def test_auto_normal_priors():
prior_fn = auto_normal_priors(sample_function, loc=2.0, scale=1.0)
with numpyro.handlers.seed(rng_seed=1):
with numpyro.handlers.trace() as tr:
prior_fn()
site1 = tr["a"]
assert_(isinstance(site1['fn'], numpyro.distributions.Normal))
assert_equal(site1['fn'].loc, 2.0)
assert_equal(site1['fn'].scale, 1.0)
site2 = tr["b"]
assert_(isinstance(site2['fn'], numpyro.distributions.Normal))
assert_equal(site2['fn'].loc, 2.0)
assert_equal(site2['fn'].scale, 1.0)

0 comments on commit 4edb082

Please sign in to comment.