Skip to content

Commit

Permalink
Updated scipy and jax versions to fix linalg.tri deprecation (#4103)
Browse files Browse the repository at this point in the history
* #3959 updated scipy and jax versions to fix deprecation error

* #3959 fix issue with vstack array dimensions

* style: pre-commit fixes

* #3959 use direct solver for interpolant

* Update pyproject.toml

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* #3959 use jax and jaxlib 0.4.27

* #3959 revert to iterative solver for interpolator and relax test tolerances

* style: pre-commit fixes

* ruff

* #3959 Eric's comments

* #3959 reduce tolerances to fix macos-14 unit tests

* Update pyproject.toml

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* #3959 relax some more tolerances

* #3959 reduce solve time in test to avoid overdischarge (aiming to fix macos failing test)

* #3959 extend solving time to reach end of discharge

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
Co-authored-by: Eric G. Kratz <kratman@users.noreply.github.com>
  • Loading branch information
4 people authored May 21, 2024
1 parent 8823b40 commit 36a9caf
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 12 deletions.
6 changes: 5 additions & 1 deletion pybamm/spatial_methods/spectral_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,11 @@ def penalty_matrix(self, domains):
e = np.zeros(n - 1)
e[d - 1 :: d] = 1 / submesh.d_nodes[d - 1 :: d]
sub_matrix = vstack(
[np.zeros(n), diags([-e, e], [0, 1], shape=(n - 1, n)), np.zeros(n)]
[
np.zeros((1, n)),
diags([-e, e], [0, 1], shape=(n - 1, n)),
np.zeros((1, n)),
]
)

# number of repeats
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ classifiers = [
]
dependencies = [
"numpy>=1.23.5",
"scipy>=1.9.3,<1.13.0",
"scipy>=1.11.4",
"casadi>=3.6.5",
"xarray>=2022.6.0",
"anytree>=2.8.0",
Expand Down Expand Up @@ -113,8 +113,8 @@ dev = [
]
# For the Jax solver. Note: these must be kept in sync with the versions defined in pybamm/util.py.
jax = [
"jax==0.4.20; python_version >= '3.9'",
"jaxlib==0.4.20; python_version >= '3.9'",
"jax==0.4.27",
"jaxlib==0.4.27",
]
# Contains all optional dependencies, except for jax and dev dependencies
all = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def compare_outputs_two_phase_silicon_graphite(self, model_class):
)

sim = pybamm.Simulation(model, parameter_values=param)
t_eval = np.linspace(0, 9000, 1000)
t_eval = np.linspace(0, 8000, 1000)
inputs = [{"x": 0.01}, {"x": 0.1}]
sol = sim.solve(t_eval, inputs=inputs)

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_expression_tree/test_interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def f(x, y):
# check also works for cubic
interp = pybamm.Interpolant(x_in, data, (var1, var2), interpolator="cubic")
value = interp.evaluate(y=np.array([1, 5]))
np.testing.assert_equal(value, f(1, 5))
np.testing.assert_almost_equal(value, f(1, 5), decimal=3)

# Test raising error if data is not 2D
data_3d = np.zeros((11, 22, 33))
Expand Down Expand Up @@ -231,7 +231,7 @@ def f(x, y, z):
x_in, data, (var1, var2, var3), interpolator="cubic"
)
value = interp.evaluate(y=np.array([1, 5, 8]))
np.testing.assert_equal(value, f(1, 5, 8))
np.testing.assert_almost_equal(value, f(1, 5, 8), decimal=3)

# Test raising error if data is not 3D
data_4d = np.zeros((11, 22, 33, 5))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def test_indefinite_integral(self):
phi_exact = np.ones((submesh.npts, 1))
phi_approx = int_grad_phi_disc.evaluate(None, phi_exact)
phi_approx += 1 # add constant of integration
np.testing.assert_array_equal(phi_exact, phi_approx)
np.testing.assert_array_almost_equal(phi_exact, phi_approx)
self.assertEqual(left_boundary_value_disc.evaluate(y=phi_exact), 0)
# linear case
phi_exact = submesh.nodes[:, np.newaxis]
Expand Down Expand Up @@ -379,7 +379,7 @@ def test_indefinite_integral(self):
phi_exact = np.ones((submesh.npts, 1))
phi_approx = int_grad_phi_disc.evaluate(None, phi_exact)
phi_approx += 1 # add constant of integration
np.testing.assert_array_equal(phi_exact, phi_approx)
np.testing.assert_array_almost_equal(phi_exact, phi_approx)
self.assertEqual(left_boundary_value_disc.evaluate(y=phi_exact), 0)

# linear case
Expand Down Expand Up @@ -440,7 +440,7 @@ def test_indefinite_integral(self):
c_exact = np.ones((submesh.npts, 1))
c_approx = c_integral_disc.evaluate(None, c_exact)
c_approx += 1 # add constant of integration
np.testing.assert_array_equal(c_exact, c_approx)
np.testing.assert_array_almost_equal(c_exact, c_approx)
self.assertEqual(left_boundary_value_disc.evaluate(y=c_exact), 0)

# linear case
Expand Down Expand Up @@ -488,7 +488,7 @@ def test_backward_indefinite_integral(self):
phi_exact = np.ones((submesh.npts, 1))
phi_approx = int_grad_phi_disc.evaluate(None, phi_exact)
phi_approx += 1 # add constant of integration
np.testing.assert_array_equal(phi_exact, phi_approx)
np.testing.assert_array_almost_equal(phi_exact, phi_approx)
self.assertEqual(right_boundary_value_disc.evaluate(y=phi_exact), 0)

# linear case
Expand Down Expand Up @@ -561,7 +561,7 @@ def test_indefinite_integral_on_nodes(self):
phi_exact = np.ones((submesh.npts, 1))
int_phi_exact = submesh.edges
int_phi_approx = int_phi_disc.evaluate(None, phi_exact).flatten()
np.testing.assert_array_equal(int_phi_exact, int_phi_approx)
np.testing.assert_array_almost_equal(int_phi_exact, int_phi_approx)
# linear case
phi_exact = submesh.nodes
int_phi_exact = submesh.edges**2 / 2
Expand Down

0 comments on commit 36a9caf

Please sign in to comment.