Merge pull request #146 from jhelgert/fix_jax_example
Improve the jax example
moorepants authored Feb 26, 2022
2 parents 8767d08 + 97c6aa2 commit 5e371be
Showing 2 changed files with 26 additions and 10 deletions.
16 changes: 13 additions & 3 deletions docs/source/tutorial.rst
@@ -63,6 +63,14 @@ and the optimal solution,
We start by importing all required libraries::

+from jax.config import config
+
+# Enable 64 bit floating point precision
+config.update("jax_enable_x64", True)
+
+# We use the CPU instead of GPU and mute all warnings if no GPU/TPU is found.
+config.update('jax_platform_name', 'cpu')
+
import jax.numpy as np
from jax import jit, grad, jacfwd
from cyipopt import minimize_ipopt
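
Enabling ``jax_enable_x64`` matters here because Ipopt works in double precision, while JAX arrays default to single precision. A minimal check of the flag's effect (a sketch, not part of this diff, reusing the imports shown above):

    from jax.config import config

    # Set the flag before any arrays are created, so that jax.numpy
    # then defaults to float64 instead of float32.
    config.update("jax_enable_x64", True)

    import jax.numpy as np

    x = np.array([1.0, 5.0, 5.0, 1.0])
    print(x.dtype)  # float64; without the flag this prints float32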
@@ -102,11 +110,13 @@ Next, we build the derivatives and just-in-time (jit) compile the functions
Finally, we can call ``minimize_ipopt`` similar to ``scipy.optimize.minimize``::

# constraints
-cons = [{'type': 'eq', 'fun': con_eq_jit, 'jac': con_eq_jac, 'hess': con_eq_hess},
-        {'type': 'ineq', 'fun': con_ineq_jit, 'jac': con_ineq_jac, 'hess': con_ineq_hess}]
+cons = [
+    {'type': 'eq', 'fun': con_eq_jit, 'jac': con_eq_jac, 'hess': con_eq_hessvp},
+    {'type': 'ineq', 'fun': con_ineq_jit, 'jac': con_ineq_jac, 'hess': con_ineq_hessvp}
+]

# starting point
-x0 = np.array([1, 5, 5, 1])
+x0 = np.array([1.0, 5.0, 5.0, 1.0])

# variable bounds: 1 <= x[i] <= 5
bnds = [(1, 5) for _ in range(x0.size)]
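
The hunk above stops before the solver call itself. As a rough sketch of how these pieces are then passed to ``minimize_ipopt`` (the names ``obj_jit``, ``obj_grad`` and ``obj_hess`` are assumed to be the jit-compiled objective, gradient and Hessian built in the part of the tutorial not shown here):

    # sketch only: obj_jit, obj_grad, obj_hess are assumed to be the
    # jit-compiled objective and its derivatives defined earlier
    res = minimize_ipopt(obj_jit, jac=obj_grad, hess=obj_hess, x0=x0,
                         bounds=bnds, constraints=cons)
    print(res.x)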
20 changes: 13 additions & 7 deletions examples/hs071_scipy_jax.py
@@ -1,7 +1,15 @@
-import jax.numpy as np
-import jax
-from jax import jit, grad, jacfwd, jacrev
+from jax.config import config
+
+# Enable 64 bit floating point precision
+config.update("jax_enable_x64", True)
+
+# We use the CPU instead of GPU and mute all warnings if no GPU/TPU is found.
+config.update('jax_platform_name', 'cpu')
+
+from cyipopt import minimize_ipopt
+from jax import jit, grad, jacrev, jacfwd
+import jax.numpy as np


# Test the scipy interface on the Hock & Schittkowski test problem 71:
#
@@ -14,9 +22,6 @@
# We evaluate all derivatives (except the Hessian) by algorithmic differentiation
# by means of the JAX library.

-# We use the CPU instead of GPU und mute all warnings if no GPU/TPU is found.
-jax.config.update('jax_platform_name', 'cpu')
-

def objective(x):
    return x[0]*x[3]*np.sum(x[:3]) + x[2]
@@ -46,13 +51,14 @@ def ineq_constrains(x):
con_ineq_hessvp = jit(lambda x, v: con_ineq_hess(x) * v[0]) # hessian vector-product

# constraints
+# Note that 'hess' is the hessian-vector-product
cons = [
{'type': 'eq', 'fun': con_eq_jit, 'jac': con_eq_jac, 'hess': con_eq_hessvp},
{'type': 'ineq', 'fun': con_ineq_jit, 'jac': con_ineq_jac, 'hess': con_ineq_hessvp},
]

# initial guess
-x0 = np.array([1, 5, 5, 1])
+x0 = np.array([1.0, 5.0, 5.0, 1.0])

# variable bounds: 1 <= x[i] <= 5
bnds = [(1, 5) for _ in range(x0.size)]
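
The 'hess' entries in the constraint dictionaries are Hessian-vector products: cyipopt expects the constraint Hessian contracted with the Lagrange multipliers, not the raw Hessian. A rough sketch of how such callables can be built with JAX for the equality constraint, mirroring the pattern visible in the hunk above (the actual definitions sit in the collapsed part of the file; the helper names here are illustrative):

    from jax import jit, grad, jacfwd, jacrev
    import jax.numpy as np

    def objective(x):
        return x[0]*x[3]*np.sum(x[:3]) + x[2]

    def eq_constraints(x):
        # HS071 equality constraint: the sum of squares equals 40
        return np.sum(x**2) - 40

    # objective: jit-compiled function, gradient and dense Hessian
    obj_jit = jit(objective)
    obj_grad = jit(grad(obj_jit))
    obj_hess = jit(jacrev(jacfwd(obj_jit)))

    # constraint: Jacobian plus the Hessian contracted with the multiplier v,
    # which is what the 'hess' entry of the constraint dict receives
    con_eq_jit = jit(eq_constraints)
    con_eq_jac = jit(jacfwd(con_eq_jit))
    con_eq_hess = jacrev(jacfwd(con_eq_jit))
    con_eq_hessvp = jit(lambda x, v: con_eq_hess(x) * v[0])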
