Cannot import gpjax after installing tensorflow #397

Closed
rgiordan opened this issue Oct 4, 2023 · 3 comments
Labels
bug Something isn't working

Comments

@rgiordan

rgiordan commented Oct 4, 2023

Bug Report

gpjax==0.7.0

Current behavior:

gpjax cannot be imported after installing tensorflow==2.14.0 using pip.

Expected behavior:

I should be able to run the following without errors:

$ python3 -m pip install gpjax
$ python3 -m pip install tensorflow
$ python3 -c 'import gpjax'

Steps to reproduce:

$ python3 -m pip install gpjax
... installs successfully
$ python3 -m pip freeze # see pip freeze 1 below
$ python3 -c 'import gpjax'
... imports successfully.

$ python3 -m pip install tensorflow
... installs successfully
$ python3 -m pip freeze # see pip freeze 2 below
$ python3 -c 'import gpjax'
...results in:

2023-10-04 15:02:18.661354: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-04 15:02:18.661385: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-04 15:02:18.661402: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-10-04 15:02:19.078384: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/tmp/venv/lib/python3.10/site-packages/gpjax/__init__.py", line 16, in <module>
    from gpjax.base import (
  File "/tmp/venv/lib/python3.10/site-packages/gpjax/base/__init__.py", line 16, in <module>
    from gpjax.base.module import (
  File "/tmp/venv/lib/python3.10/site-packages/gpjax/base/module.py", line 49, in <module>
    import tensorflow_probability.substrates.jax.bijectors as tfb
  File "/tmp/venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/__init__.py", line 41, in <module>
    from tensorflow_probability.substrates.jax import bijectors
  File "/tmp/venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/bijectors/__init__.py", line 19, in <module>
    from tensorflow_probability.substrates.jax.bijectors.absolute_value import AbsoluteValue
  File "/tmp/venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/bijectors/absolute_value.py", line 19, in <module>
    from tensorflow_probability.substrates.jax.bijectors import bijector
  File "/tmp/venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 26, in <module>
    from tensorflow_probability.substrates.jax.internal import batch_shape_lib
  File "/tmp/venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/batch_shape_lib.py", line 23, in <module>
    from tensorflow_probability.substrates.jax.internal import prefer_static as ps
  File "/tmp/venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/prefer_static.py", line 361, in <module>
    ones_like = _copy_docstring(tf.ones_like, _ones_like)
  File "/tmp/venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/prefer_static.py", line 84, in _copy_docstring
    raise ValueError(
ValueError: Arg specs do not match: original=FullArgSpec(args=['input', 'dtype', 'name', 'layout'], varargs=None, varkw=None, defaults=(None, None, None), kwonlyargs=[], kwonlydefaults=None, annotations={}), new=FullArgSpec(args=['input', 'dtype', 'name'], varargs=None, varkw=None, defaults=(None, None), kwonlyargs=[], kwonlydefaults=None, annotations={}), fn=<function ones_like_v2 at 0x7fa3958cd3f0>

Related code:

Other information:

After the above steps, I am able to import tensorflow successfully:
$ python3 -c 'import tensorflow' # succeeds

Pip freeze 1:

absl-py==2.0.0
beartype==0.13.1
chex==0.1.83
cloudpickle==2.2.1
cola-ml==0.0.1
cola-plum-dispatch==0.1.1
decorator==5.1.1
dm-tree==0.1.8
etils==1.5.0
fsspec==2023.9.2
gast==0.5.4
gpjax==0.7.0
importlib-resources==6.1.0
jax==0.4.17
jaxlib==0.4.17
jaxtyping==0.2.22
ml-dtypes==0.3.1
msgpack==1.0.7
nest-asyncio==1.5.8
numpy==1.26.0
opt-einsum==3.3.0
optax==0.1.7
orbax-checkpoint==0.4.1
protobuf==4.24.4
PyYAML==6.0.1
scipy==1.11.3
simple-pytree==0.1.7
six==1.16.0
tensorflow-probability==0.19.0
tensorstore==0.1.45
toolz==0.12.0
tqdm==4.66.1
typeguard==4.1.5
typing_extensions==4.8.0
zipp==3.17.0

Pip freeze 2:

absl-py==2.0.0
astunparse==1.6.3
beartype==0.13.1
cachetools==5.3.1
certifi==2023.7.22
charset-normalizer==3.3.0
chex==0.1.83
cloudpickle==2.2.1
cola-ml==0.0.1
cola-plum-dispatch==0.1.1
decorator==5.1.1
dm-tree==0.1.8
etils==1.5.0
flatbuffers==23.5.26
fsspec==2023.9.2
gast==0.5.4
google-auth==2.23.2
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
gpjax==0.7.0
grpcio==1.59.0
h5py==3.9.0
idna==3.4
importlib-resources==6.1.0
jax==0.4.17
jaxlib==0.4.17
jaxtyping==0.2.22
keras==2.14.0
libclang==16.0.6
Markdown==3.4.4
MarkupSafe==2.1.3
ml-dtypes==0.2.0
msgpack==1.0.7
nest-asyncio==1.5.8
numpy==1.26.0
oauthlib==3.2.2
opt-einsum==3.3.0
optax==0.1.7
orbax-checkpoint==0.4.1
packaging==23.2
protobuf==4.24.4
pyasn1==0.5.0
pyasn1-modules==0.3.0
PyYAML==6.0.1
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.11.3
simple-pytree==0.1.7
six==1.16.0
tensorboard==2.14.1
tensorboard-data-server==0.7.1
tensorflow==2.14.0
tensorflow-estimator==2.14.0
tensorflow-io-gcs-filesystem==0.34.0
tensorflow-probability==0.19.0
tensorstore==0.1.45
termcolor==2.3.0
toolz==0.12.0
tqdm==4.66.1
typeguard==4.1.5
typing_extensions==4.8.0
urllib3==2.0.6
Werkzeug==3.0.0
wrapt==1.14.1
zipp==3.17.0
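Comparing the two freezes by eye is tedious. A small helper can diff them instead. The helper below is hypothetical and not a pip feature. Run on the excerpt in the code, it reports tensorflow as added and ml-dtypes moving from 0.3.1 to 0.2.0, while tensorflow-probability stays at 0.19.0 in both environments. On the full lists, it would also report tensorflow's other new dependencies, such as keras, tensorboard and grpcio.

```python
def parse_freeze(text):
    """Map lowercased package name -> version for each `name==version` line."""
    packages = {}
    for line in text.splitlines():
        line = line.strip()
        if "==" not in line:
            continue  # skip blank lines and non-pinned entries
        name, version = line.split("==", 1)
        packages[name.lower()] = version
    return packages


def diff_freeze(before, after):
    """Return (added, removed, changed) between two parsed freezes."""
    added = {n: v for n, v in after.items() if n not in before}
    removed = {n: v for n, v in before.items() if n not in after}
    changed = {
        n: (before[n], after[n])
        for n in sorted(before.keys() & after.keys())
        if before[n] != after[n]
    }
    return added, removed, changed


# Excerpts copied from "Pip freeze 1" and "Pip freeze 2" above.
FREEZE_1 = """
jax==0.4.17
jaxlib==0.4.17
ml-dtypes==0.3.1
tensorflow-probability==0.19.0
"""

FREEZE_2 = """
jax==0.4.17
jaxlib==0.4.17
ml-dtypes==0.2.0
tensorflow==2.14.0
tensorflow-probability==0.19.0
"""

if __name__ == "__main__":
    added, removed, changed = diff_freeze(parse_freeze(FREEZE_1), parse_freeze(FREEZE_2))
    print("added:", added)
    print("removed:", removed)
    print("changed:", changed)
```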

rgiordan added the bug label (Something isn't working) on Oct 4, 2023
@daniel-dodd
Member

Thanks @rgiordan. This looks to me like an incompatibility between our tensorflow_probability.substrates.jax dependency at v0.20.0 and tensorflow v2.14.0. Bumping tensorflow_probability to 0.22.0 seems to resolve the issue with tensorflow v2.14.0, but we then run into conflicts with some of our other dependencies.
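The maintainer's comment suggests that tensorflow v2.14.0 needs tensorflow_probability 0.22.0 or newer. A user could detect the problematic pair up front before importing gpjax. The sketch below is hypothetical. The function names and the single "TF >= 2.14 needs TFP >= 0.22" rule are assumptions drawn from this thread, not an official compatibility table. It reads installed versions with the standard-library `importlib.metadata`.

```python
from importlib import metadata


def parse_version(text):
    """Turn '2.14.0' into (2, 14, 0); suffixes like 'rc1' are ignored."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def tfp_too_old_for_tf(tf_version, tfp_version):
    """Assumed rule from this thread: TF >= 2.14 wants TFP >= 0.22."""
    return parse_version(tf_version) >= (2, 14) and parse_version(tfp_version) < (0, 22)


def check_installed():
    """Return True/False for the installed pair, or None if either is missing."""
    try:
        tf_version = metadata.version("tensorflow")
        tfp_version = metadata.version("tensorflow-probability")
    except metadata.PackageNotFoundError:
        return None  # nothing to compare
    return tfp_too_old_for_tf(tf_version, tfp_version)


if __name__ == "__main__":
    if check_installed():
        print("tensorflow-probability is likely too old for this tensorflow; "
              "try upgrading it or downgrading tensorflow.")
```

For the environment in "Pip freeze 2" (tensorflow 2.14.0, tensorflow-probability 0.19.0), this rule flags the pair. It only encodes the single data point reported in this thread.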

@vabor112

vabor112 commented Mar 8, 2024

I have also run into this and had to downgrade tensorflow in the end. It would be great if you could upgrade to a newer version of tensorflow_probability in a future release.

@thomaspinder
Collaborator

This should now be resolved by #442.

4 participants