diff --git a/gpax/utils.py b/gpax/utils.py index 8eede8b..f7b11c4 100644 --- a/gpax/utils.py +++ b/gpax/utils.py @@ -7,7 +7,8 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ -from typing import Union, Dict, Type, List +import inspect +from typing import Union, Dict, Type, List, Callable import jax import jax.numpy as jnp @@ -264,3 +265,64 @@ def uniform_dist(low: float = None, high = high if high is not None else input_vec.max() return numpyro.distributions.Uniform(low, high) + + +def set_fn(func: Callable) -> Callable: + """ + Transforms the given deterministic function to use a params dictionary + for its parameters, excluding the first one (assumed to be the dependent variable). + + Args: + - func (Callable): The deterministic function to be transformed. + + Returns: + - Callable: The transformed function where parameters are accessed + from a `params` dictionary. + """ + # Extract parameter names excluding the first one (assumed to be the dependent variable) + params_names = list(inspect.signature(func).parameters.keys())[1:] + + # Create the transformed function definition + transformed_code = f"def {func.__name__}(x, params):\n" + + # Retrieve the source code of the function and indent it to be a valid function body + source = inspect.getsource(func).split("\n", 1)[1] + source = " " + source.replace("\n", "\n ") + + # Replace each parameter name with its dictionary lookup + for name in params_names: + source = source.replace(f" {name}", f' params["{name}"]') + + # Combine to get the full source + transformed_code += source + + # Define the transformed function in the local namespace + local_namespace = {} + exec(transformed_code, globals(), local_namespace) + + # Return the transformed function + return local_namespace[func.__name__] + + +def auto_normal_priors(func: Callable, loc: float = 0.0, scale: float = 1.0) -> Callable: + """ + Generates a function that, when invoked, samples from normal distributions + for each parameter of the given deterministic function, except the first one. + + Args: + - func (Callable): The deterministic function for which to set normal priors. + - loc (float, optional): Mean of the normal distribution. Defaults to 0.0. + - scale (float, optional): Standard deviation of the normal distribution. Defaults to 1.0. + + Returns: + - Callable: A function that, when invoked, returns a dictionary of sampled values + from normal distributions for each parameter of the original function. + """ + # Get the names of the parameters of the function excluding the first one (dependent variable) + params_names = list(inspect.signature(func).parameters.keys())[1:] + + def sample_priors() -> Dict[str, Union[float, Type[Callable]]]: + # Return a dictionary with normal priors for each parameter + return {name: place_normal_prior(name, loc, scale) for name in params_names} + + return sample_priors diff --git a/tests/test_utils.py b/tests/test_utils.py index 18aa749..62d0025 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,8 +10,12 @@ from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys from gpax.utils import place_normal_prior, place_halfnormal_prior, place_uniform_prior, place_gamma_prior, gamma_dist, uniform_dist, normal_dist, halfnormal_dist +from gpax.utils import set_fn, auto_normal_priors +def sample_function(x, a, b): + return a + b * x + def test_sparse_img_processing(): img = onp.random.randn(16, 16) # Generate random indices @@ -221,4 +225,23 @@ def test_get_gamma_dist_error(): uniform_dist() # Neither concentration, nor input_vec +def test_set_fn(): + transformed_fn = set_fn(sample_function) + result = transformed_fn(2, {"a": 1, "b": 3}) + assert result == 7 # Expected output: 1 + 3*2 = 7 + + +def test_auto_normal_priors(): + prior_fn = auto_normal_priors(sample_function, loc=2.0, scale=1.0) + with numpyro.handlers.seed(rng_seed=1): + with numpyro.handlers.trace() as tr: + prior_fn() + site1 = tr["a"] + assert_(isinstance(site1['fn'], numpyro.distributions.Normal)) + assert_equal(site1['fn'].loc, 2.0) + assert_equal(site1['fn'].scale, 1.0) + site2 = tr["b"] + assert_(isinstance(site2['fn'], numpyro.distributions.Normal)) + assert_equal(site2['fn'].loc, 2.0) + assert_equal(site2['fn'].scale, 1.0)