JAX BDF solver tests failing / update jax versions (due to scipy.linalg.tril deprecation)
#3959
It's probably not as trivial as bumping the JAX version: there are a few other errors, involving JAX's JIT and spectral volumes, that I don't understand yet. I'm putting this aside for a bit and will return to it soon; others should feel free to proceed if they make progress.
Bumping to v0.4.24 fixes at least some of the tests; earlier versions still hit the SciPy error.
It is worthwhile to bump jax as high as possible. We have people experienced with JAX who might be able to help. We are going to run into more compatibility issues as the code ages.
I agree with you; v0.4.26 is their latest release. Should we drop the pin altogether? It might break on v0.5.X, so having
Pinning is fine, so there are no unexpected changes. Realistically, we should have all major dependencies pinned. Something like Dependabot should handle the updates so that the failures all show up in one place.
Do you need help with this one?
We shouldn't pin to exact versions, as that may cause compatibility issues for our users (for example, if they use pybamm alongside another package that happens to pin numpy to a different version). We can specify ranges, but they should be as wide as possible.
jax is an exception: we have to pin the exact version, since every release changes the API.
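As a rough illustration of the policy above (wide ranges for most dependencies, an exact pin only for jax), here is a minimal Python sketch. The helper names (parse, satisfies, check) and the POLICY table are hypothetical, not PyBaMM code; jax/jaxlib 0.4.27 is the pin the eventual fix moved to, and the SciPy entry mirrors the temporary <1.13.0 cap mentioned later in this thread. The version parser is deliberately naive and only handles plain dotted integers; real tooling should rely on the packaging library, which implements PEP 440.

```python
from __future__ import annotations

from typing import Optional


def parse(version: str) -> tuple[int, ...]:
    """Parse a plain dotted version such as '0.4.27' into (0, 4, 27).

    Trailing zeros are dropped so that '1.13' and '1.13.0' compare equal.
    Pre-release, post-release and local tags (PEP 440) are NOT handled.
    """
    parts = [int(p) for p in version.split(".")]
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def satisfies(
    version: str,
    *,
    exact: Optional[str] = None,
    lower: Optional[str] = None,
    upper: Optional[str] = None,
) -> bool:
    """Check an exact pin (==) or a half-open range lower <= version < upper."""
    v = parse(version)
    if exact is not None:
        return v == parse(exact)
    if lower is not None and v < parse(lower):
        return False
    if upper is not None and v >= parse(upper):
        return False
    return True


# Hypothetical policy table: exact pins only where every release changes the API,
# wide (here: upper-bound-only) ranges elsewhere.
POLICY: dict[str, dict[str, str]] = {
    "jax": {"exact": "0.4.27"},
    "jaxlib": {"exact": "0.4.27"},
    "scipy": {"upper": "1.13.0"},
}


def check(installed: dict[str, str]) -> list[str]:
    """Return, sorted, the packages whose installed version breaks POLICY."""
    return sorted(
        name
        for name, ver in installed.items()
        if name in POLICY and not satisfies(ver, **POLICY[name])
    )
```

For example, check({"jax": "0.4.26", "jaxlib": "0.4.27", "scipy": "1.13.0"}) returns ["jax", "scipy"]: jax misses its exact pin and SciPy hits the cap, while packages absent from POLICY are ignored. Installed versions could be gathered with importlib.metadata.version(name) from the standard library, though pre-release or local version strings would need a real PEP 440 parser.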
I would appreciate that, as someone who hasn't used JAX much. I was able to get the tests to pass with newer versions of JAX (some of those can be ignored, because it's probably not caching the solves properly on my machine). Some spatial methods tests are still failing, where I received
To add to this, we have been keeping the lower bounds in sync with the package versions available on conda-forge (too low a lower bound caused some trouble around the PyBaMM 23.9 release). It might make sense to drop Python 3.8 soon, since its tests have only been passing because of the use of deprecated code?
I was planning on putting up a PR for that this week. It seemed to align with the removal of ODEs and the removal of the JAX Windows restrictions. I will probably just go ahead and make that PR while helping with the JAX work. I should have a bit of time to take a look this afternoon. Just share the branch you are working on and I will see what I can do to help.
I don't have a branch or anything concrete; I was only debugging locally. I'll add the link here once I get back to it.
Yeah, let's follow NumPy's lead on which Python versions we support; they have dropped support for 3.8.
I was checking this issue: hasn't this been solved by the PRs Eric referenced above? When I looked, the CI tests seemed to be passing.
Ah, that is still only one part of the issue. The other part is that we still need to unpin SciPy, which is currently capped at <1.13.0, IIRC.
So if we unpin SciPy, the tests fail, right? Is there a branch where this is done so I can see the errors?
Yes. I only tried locally last time and was just going to open a PR to show you the logs, but I'm facing a strange error locally right now:
and
Edit: I see that you opened a PR just as I commented :)
* #3959 updated scipy and jax versions to fix deprecation error
* #3959 fix issue with vstack array dimensions
* style: pre-commit fixes
* #3959 use direct solver for interpolant
* Update pyproject.toml
  Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
* #3959 use jax and jaxlib 0.4.27
* #3959 revert to iterative solver for interpolator and relax test tolerances
* style: pre-commit fixes
* ruff
* #3959 Eric's comments
* #3959 reduce tolerances to fix macos-14 unit tests
* Update pyproject.toml
  Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
* #3959 relax some more tolerances
* #3959 reduce solve time in test to avoid overdischarge (aiming to fix macos failing test)
* #3959 extend solving time to reach end of discharge
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
Co-authored-by: Eric G. Kratz <kratman@users.noreply.github.com>
The JAX BDF solver tests are failing on all PRs (#3846, #3945, etc.) for Python 3.9 and later, because SciPy removed some linear algebra routines (such as scipy.linalg.tril) in v1.13.0. The Python 3.8 tests still pass because SciPy had already dropped support for Python 3.8, so those jobs install an older SciPy release that still ships these routines.
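Code that called the removed SciPy helpers directly can switch to the NumPy functions of the same names, which cover the same core behaviour. A minimal sketch, assuming only that NumPy is installed; in this issue the removed call appears to have lived inside older jax releases, which is why the fix discussed here is a jax version bump rather than a change to PyBaMM's own code.

```python
import numpy as np

# scipy.linalg.tri / tril / triu no longer exist in SciPy >= 1.13.0.
# NumPy ships functions with the same names; tril/triu take the same `k` diagonal offset.
a = np.arange(1, 10).reshape(3, 3)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

lower = np.tril(a)       # main diagonal and everything below it
upper = np.triu(a, k=1)  # strictly above the main diagonal (k=1 skips the diagonal)
mask = np.tri(3)         # 3x3 float array of ones on and below the diagonal

# The lower-triangular part plus the strictly upper part rebuilds the original matrix.
assert np.array_equal(lower + upper, a)
```

Calling NumPy unconditionally avoids any version check, since these NumPy functions exist on both old and new SciPy setups.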
I'm guessing we need to bump the jax and jaxlib versions now, or relax the pin in the requirements, because there have been quite a few releases since v0.4.20; the current version at the time of writing is v0.4.25.
Linked PR: #4103 (linalg.tri deprecation)