Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JaxObjective #1400

Merged
merged 10 commits into from
May 29, 2024
43 changes: 32 additions & 11 deletions pypesto/objective/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import copy
from functools import partial
from typing import Union
from typing import Callable, Union

import numpy as np

Expand All @@ -26,6 +26,18 @@
"`pip install jax jaxlib`."
) from None


def _base_objective_as_jax_array_tuple(func: Callable):
def decorator(*args, **kwargs):
# make sure return is a tuple of jax arrays
results = func(*args, **kwargs)
if isinstance(results, tuple):
return tuple(jnp.array(r) for r in results)
return jnp.array(results)

return decorator


# jax compatible (jit-able) objective function using external callback, see
# https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html

Expand All @@ -42,7 +54,9 @@ def _device_fun(base_objective: ObjectiveBase, x: jnp.array):
jax computed input array.
"""
return jax.pure_callback(
FFroehlich marked this conversation as resolved.
Show resolved Hide resolved
partial(base_objective, sensi_orders=(0,)),
_base_objective_as_jax_array_tuple(
partial(base_objective, sensi_orders=(0,))
),
jax.ShapeDtypeStruct((), x.dtype),
x,
)
Expand All @@ -65,12 +79,14 @@ def _device_fun_value_and_grad(base_objective: ObjectiveBase, x: jnp.array):
jax computed input array.
"""
return jax.pure_callback(
partial(
base_objective,
sensi_orders=(
0,
1,
),
_base_objective_as_jax_array_tuple(
partial(
base_objective,
sensi_orders=(
0,
1,
),
)
),
(
jax.ShapeDtypeStruct((), x.dtype),
Expand Down Expand Up @@ -204,15 +220,20 @@ def __deepcopy__(self, memodict=None):

@property
def history(self):
    """Expose the history of the inner objective.

    Delegates to ``self.base_objective.history`` so callers can
    access recorded evaluations directly on the wrapper.
    """
    return self.base_objective.history

@property
def pre_post_processor(self):
    """Expose the pre_post_processor of the inner objective."""
    return self.base_objective.pre_post_processor

@pre_post_processor.setter
def pre_post_processor(self, new_pre_post_processor):
    """Set the pre_post_processor of the inner objective.

    Forwards the assignment so that fixing/freeing parameters on the
    wrapper takes effect on the wrapped objective.
    """
    self.base_objective.pre_post_processor = new_pre_post_processor

@property
def x_names(self):
    """Expose the x_names of the inner objective.

    Delegates to ``self.base_objective.x_names`` so parameter names
    stay consistent between wrapper and wrapped objective.
    """
    return self.base_objective.x_names
24 changes: 21 additions & 3 deletions test/base/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ def test_aesara(max_sensi_order, integrated):


@pytest.mark.parametrize("enable_x64", [True, False])
def test_jax(max_sensi_order, integrated, enable_x64):
@pytest.mark.parametrize("fix_parameters", [True, False])
def test_jax(max_sensi_order, integrated, enable_x64, fix_parameters):
"""Test function composition and gradient computation via jax"""
import jax
import jax.numpy as jnp
Expand All @@ -234,6 +235,7 @@ def test_jax(max_sensi_order, integrated, enable_x64):
jax.config.update("jax_enable_x64", enable_x64)

from pypesto.objective.jax import JaxObjective
from pypesto.objective.pre_post_process import FixedParametersProcessor

prob = rosen_for_sensi(max_sensi_order, integrated, [0, 1])

Expand All @@ -250,9 +252,20 @@ def jax_op_out(x: jnp.array) -> jnp.array:
# compose rosenbrock function with sinh transformation
obj = JaxObjective(prob["obj"])

if fix_parameters:
obj.pre_post_processor = FixedParametersProcessor(
dim_full=2,
x_free_indices=[0],
x_fixed_indices=[1],
x_fixed_vals=[0.0],
)

# evaluate for a couple of random points such that we can assess
# compatibility with vmap
xx = x_ref + np.random.randn(10, x_ref.shape[0])
if fix_parameters:
xx = xx[:, obj.pre_post_processor.x_free_indices]

rvals_ref = [
jax_op_out(
prob["obj"](jax_op_in(xxi), sensi_orders=(max_sensi_order,))
Expand Down Expand Up @@ -281,8 +294,11 @@ def _fun(y, pypesto_fun, jax_fun_in, jax_fun_out):
# can't use rtol = 1e-8 for 32bit
rtol = 1e-16 if enable_x64 else 1e-4
for x, rref, rj in zip(xx, rvals_ref, rvals_jax):
assert isinstance(rj, jnp.ndarray)
if max_sensi_order == 0:
np.testing.assert_allclose(rref, rj, atol=atol, rtol=rtol)
np.testing.assert_allclose(
rref, float(rj), atol=atol, rtol=rtol
)
if max_sensi_order == 1:
FFroehlich marked this conversation as resolved.
Show resolved Hide resolved
# g(x) = b(c(x)) => g'(x) = b'(c(x))) * c'(x)
# f(x) = a(g(x)) => f'(x) = a'(g(x)) * g'(x)
Expand All @@ -295,7 +311,9 @@ def _fun(y, pypesto_fun, jax_fun_in, jax_fun_out):
) @ jax.jacfwd(jax_op_in)(x)
# f'(x) = a'(g(x)) * g'(x)
f_prime = jax.jacfwd(jax_op_out)(g) * g_prime
np.testing.assert_allclose(f_prime, rj, atol=atol, rtol=rtol)
np.testing.assert_allclose(
f_prime, np.asarray(rj), atol=atol, rtol=rtol
)


@pytest.fixture(
Expand Down
Loading