Hey y'all! I think jax[cuda12] is incompatible with tensorflow[and-cuda], and I just wanted to clarify whether this is expected.
The solve error I'm getting is:
Because tensorflow[and-cuda] (2.15.0) depends on nvidia-nccl-cu12 (2.16.5)
and jax[cuda12] (0.4.23) depends on nvidia-nccl-cu12 (>=2.18.3), tensorflow[and-cuda] (2.15.0) is incompatible with jax[cuda12] (0.4.23).
So, because hex-packages depends on both jax[cuda12] (0.4.23) and tensorflow[and-cuda] (2.15.0), version solving failed.
None of the jax[cuda12] versions with GPU support accept nvidia-nccl-cu12==2.16.5. Could this requirement be loosened to accommodate lower versions of nvidia-nccl-cu12?
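The conflict in the error message comes down to two constraints on the same wheel that have no overlap. This is a minimal sketch of that check. It assumes plain dotted integer versions with no pre- or post-release tags. The helper names (`parse_version`, `satisfies`, `compatible`) are hypothetical and used only for illustration. Real installers evaluate PEP 440 specifiers, for example via `packaging.specifiers.SpecifierSet`.

```python
# Minimal sketch of why the resolver reports a conflict on nvidia-nccl-cu12.
# Assumption: versions are plain dotted integers (no pre/post-release tags),
# so comparing them as integer tuples is enough for this illustration.

def parse_version(v: str) -> tuple:
    """Turn '2.18.3' into (2, 18, 3) so versions compare in order."""
    return tuple(int(part) for part in v.split("."))

def satisfies(version: str, op: str, bound: str) -> bool:
    """Check one 'op bound' constraint, e.g. ('>=', '2.18.3')."""
    v, b = parse_version(version), parse_version(bound)
    checks = {
        "==": v == b,
        ">=": v >= b,
        "<=": v <= b,
        ">": v > b,
        "<": v < b,
    }
    return checks[op]

# Constraints copied from the error message above.
TENSORFLOW_PIN = ("==", "2.16.5")  # tensorflow[and-cuda] 2.15.0
JAX_FLOOR = (">=", "2.18.3")       # jax[cuda12] 0.4.23

def compatible(candidate: str) -> bool:
    """True only if a single NCCL version satisfies both packages."""
    return satisfies(candidate, *TENSORFLOW_PIN) and satisfies(candidate, *JAX_FLOOR)

if __name__ == "__main__":
    # The '==' pin admits exactly one version, and it sits below JAX's floor.
    print(compatible("2.16.5"))
```

Because the TensorFlow side is an exact pin, checking that one pinned version is enough to show that no shared version exists.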
System info (Python version, jaxlib version, accelerator, etc.)
Python 3.9 and Python 3.10
JAX versions 0.4.23, 0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28
CUDA 12.2
V100 GPUs
Could you please update TensorFlow to 2.16.1 and JAX to 0.4.28 and check whether they are compatible? TensorFlow 2.16.1 uses nvidia-nccl-cu12==2.19.3, which satisfies JAX's requirement of nvidia-nccl-cu12>=2.18.3. I have tested this on a Colab Tesla T4 GPU and it worked fine for me.
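To check which versions actually ended up in an environment after the upgrade, you can read the installed package metadata with the standard library's `importlib.metadata`. This is a minimal sketch. The `>=2.18.3` floor is copied from the error message earlier in this thread and may differ for other JAX releases. `installed` and `release_tuple` are hypothetical helper names.

```python
# Sketch: print installed versions of the packages discussed in this thread
# and check the NCCL wheel against JAX's floor.
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ["jax", "jaxlib", "tensorflow", "nvidia-nccl-cu12"]
# Assumption: floor taken from the jax[cuda12] 0.4.23 error message.
JAX_NCCL_FLOOR = (2, 18, 3)

def installed(name: str):
    """Return the installed version string, or None if the package is absent."""
    try:
        return version(name)
    except PackageNotFoundError:
        return None

def release_tuple(ver: str) -> tuple:
    """Keep the leading numeric components, e.g. '2.19.3.post1' -> (2, 19, 3)."""
    parts = []
    for part in ver.split("."):
        if not part.isdigit():
            break
        parts.append(int(part))
    return tuple(parts)

if __name__ == "__main__":
    for name in PACKAGES:
        print(f"{name}: {installed(name) or 'not installed'}")
    nccl = installed("nvidia-nccl-cu12")
    if nccl is not None:
        print("NCCL meets JAX floor:", release_tuple(nccl) >= JAX_NCCL_FLOOR)
```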
I tested this on a GCP VM instance with Ubuntu 20.04 LTS and a Tesla P100 GPU. I first installed JAX 0.4.28 using:
pip install -U "jax[cuda12]"
and then installed TensorFlow 2.15.1 using:
pip install -U "tensorflow[and-cuda]==2.15.1"
This successfully installed both JAX and TensorFlow, and both detect the GPU. Please find the screenshot for reference.
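The GPU check described above can be reproduced with a short script that asks each framework for its devices in the same environment. It uses `jax.devices()` and `tf.config.list_physical_devices("GPU")`. This is a sketch, not an official diagnostic. The imports are guarded, and any failure is treated as "no GPU", so the script still runs when a package is missing. The function names are hypothetical.

```python
# Sketch: confirm that JAX and TensorFlow both see a GPU in one environment.
# Any import or backend-initialisation error is treated as "no GPU found".

def jax_gpus() -> list:
    """Return JAX devices on the GPU platform, or [] if unavailable."""
    try:
        import jax
        return [d for d in jax.devices() if d.platform == "gpu"]
    except Exception:  # ImportError or a backend initialisation failure
        return []

def tf_gpus() -> list:
    """Return TensorFlow's physical GPU devices, or [] if unavailable."""
    try:
        import tensorflow as tf
        return list(tf.config.list_physical_devices("GPU"))
    except Exception:  # ImportError or a CUDA library loading failure
        return []

if __name__ == "__main__":
    print("JAX GPUs:", jax_gpus())
    print("TensorFlow GPUs:", tf_gpus())
```

Non-empty lists from both calls correspond to the "both detect the GPU" result reported here.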
As per this comment, JAX 0.4.26 relaxed its CUDA version dependencies, so the minimum CUDA version for JAX is now 12.1. Installing JAX 0.4.26 or later first and then installing TensorFlow 2.15.1 might resolve this issue.