jax[cuda12] compatibility with tensorflow[and-cuda] 2.15.0/2.15.1 #21335

Closed
attaluris opened this issue May 21, 2024 · 3 comments
Labels
bug Something isn't working

@attaluris

Description

Hey y'all! I think jax[cuda12] is incompatible with tensorflow[and-cuda] 2.15.0/2.15.1, and I wanted to check whether this is expected.
The dependency solver error I'm getting is:

Because tensorflow[and-cuda] (2.15.0) depends on nvidia-nccl-cu12 (2.16.5)
 and jax[cuda12] (0.4.23) depends on nvidia-nccl-cu12 (>=2.18.3), tensorflow[and-cuda] (2.15.0) is incompatible with jax[cuda12] (0.4.23).
So, because hex-packages depends on both jax[cuda12] (0.4.23) and tensorflow[and-cuda] (2.15.0), version solving failed.

None of the jax[cuda12] versions with GPU support accept nvidia-nccl-cu12==2.16.5. Could this requirement be loosened to accommodate lower versions of nvidia-nccl-cu12?
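
To make the conflict concrete, here is a minimal sketch of the comparison the solver is reporting. It uses a simplified tuple comparison of dotted version numbers, not pip's or Poetry's actual resolver logic, and the helper names `parse_version` and `satisfies_minimum` are hypothetical. The version numbers are the ones from the error message above.

```python
def parse_version(text):
    """Turn a dotted version string such as '2.16.5' into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))


def satisfies_minimum(pinned, minimum):
    """Return True if an exactly pinned version meets a '>=' lower bound."""
    return parse_version(pinned) >= parse_version(minimum)


# tensorflow[and-cuda] 2.15.0 pins nvidia-nccl-cu12 to 2.16.5,
# while jax[cuda12] 0.4.23 requires nvidia-nccl-cu12 >= 2.18.3.
tf_pin = "2.16.5"
jax_minimum = "2.18.3"

compatible = satisfies_minimum(tf_pin, jax_minimum)
print(f"nvidia-nccl-cu12 {tf_pin} satisfies >={jax_minimum}: {compatible}")
```

Because the pin sits below the lower bound, no single nvidia-nccl-cu12 version can satisfy both packages, which is why the solve fails rather than picking a compromise version.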

System info (python version, jaxlib version, accelerator, etc.)

python 3.9 and python 3.10
jax versions 0.4.23, 0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28
cuda 12.2
v100 GPUs

@attaluris attaluris added the bug Something isn't working label May 21, 2024
@rajasekharporeddy
Contributor

Hi @attaluris

Could you please update TensorFlow to 2.16.1 and JAX to 0.4.28 and check whether they are compatible? TensorFlow 2.16.1 uses nvidia-nccl-cu12==2.19.3, which satisfies JAX's requirement of nvidia-nccl-cu12>=2.18.3. I tested this on a Colab Tesla T4 GPU and it worked fine for me.
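
After upgrading, one way to confirm which versions actually ended up in the environment is to query the installed distribution metadata. This is a minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical, and the list of distributions is an assumption based on the packages named in this thread.

```python
from importlib import metadata


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    packages = ["jax", "jaxlib", "tensorflow", "nvidia-nccl-cu12"]
    for name, version in installed_versions(packages).items():
        print(f"{name}: {version or 'not installed'}")
```

If nvidia-nccl-cu12 reports 2.18.3 or later, it meets the lower bound quoted above.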

Thank you.

@rajasekharporeddy
Contributor

Hi @attaluris

I tested this on a GCP VM instance with Ubuntu 20.04 LTS and a Tesla P100 GPU. I first installed JAX 0.4.28 using:

pip install -U "jax[cuda12]"

and then installed TensorFlow 2.15.1 using:

pip install -U "tensorflow[and-cuda]==2.15.1"

This installed both JAX and TensorFlow successfully, and both detect the GPU. Please find the screenshot for reference.

image
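
A check along these lines can be used to confirm that both frameworks see the GPU. This is a minimal sketch, not necessarily the exact commands shown in the screenshot; the helper name `detect_devices` is hypothetical. It relies on `jax.devices()` and `tf.config.list_physical_devices("GPU")`, and reports None for a framework that is not installed.

```python
def detect_devices():
    """Return the devices each framework reports, or None if it is not installed."""
    found = {}

    try:
        import jax
        # jax.devices() lists the devices of the default backend,
        # which is the GPU backend when a CUDA-enabled jaxlib is working.
        found["jax"] = [f"{d.platform}:{d.id}" for d in jax.devices()]
    except ImportError:
        found["jax"] = None

    try:
        import tensorflow as tf
        # An empty list means TensorFlow is installed but sees no GPU.
        found["tensorflow"] = [d.name for d in tf.config.list_physical_devices("GPU")]
    except ImportError:
        found["tensorflow"] = None

    return found


if __name__ == "__main__":
    for framework, devices in detect_devices().items():
        print(f"{framework}: {devices}")
```

If the JAX entry lists only CPU devices, or the TensorFlow list is empty, the corresponding CUDA wheels were probably not picked up by the install.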

As per this comment, JAX 0.4.26 relaxed its CUDA version dependencies, so the minimum CUDA version for JAX is now 12.1. Installing JAX 0.4.26 or later and then installing TensorFlow 2.15.1 might resolve this issue.

Thank you.

@hawkinsp
Collaborator

Yes. JAX has much more relaxed version dependencies in recent releases. If you upgrade TensorFlow to 2.16.1, then the problem should be resolved.
