Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise error if JAX or jaxlib are 0.4.4 #3813

Merged
merged 18 commits into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions pennylane/devices/default_qubit_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains an jax implementation of the :class:`~.DefaultQubit`
"""This module contains a jax implementation of the :class:`~.DefaultQubit`
reference plugin.
"""
# pylint: disable=ungrouped-imports
Expand All @@ -30,11 +30,25 @@

from pennylane.pulse.parametrized_hamiltonian_pytree import ParametrizedHamiltonianPytree


except ImportError as e: # pragma: no cover
raise ImportError("default.qubit.jax device requires installing jax>0.3.20") from e


def _validate_jax_version():
import jaxlib # pylint:disable=import-outside-toplevel
lillian542 marked this conversation as resolved.
Show resolved Hide resolved

if jax.__version__ == "0.4.4":
raise RuntimeError(
"The current JAX installation is 0.4.4. The JAX implementation for default.qubit requires "
"version 0.4.3 or lower for JAX."
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
)
if jaxlib.__version__ == "0.4.4":
raise RuntimeError(
"The current jaxlib installation is 0.4.4. The JAX implementation for default.qubit "
"requires version 0.4.3 or lower for jaxlib."
)
rmoyard marked this conversation as resolved.
Show resolved Hide resolved


class DefaultQubitJax(DefaultQubit):
"""Simulator plugin based on ``"default.qubit"``, written using jax.

Expand Down Expand Up @@ -119,7 +133,7 @@ def circuit():
a = keyed_circuit(key1)
b = keyed_circuit(key2) # b will be different samples now.

Check out out the `JAX random documentation <https://jax.readthedocs.io/en/latest/jax.random.html>`__
Check out the `JAX random documentation <https://jax.readthedocs.io/en/latest/jax.random.html>`__
for more information.

Args:
Expand All @@ -136,6 +150,8 @@ def circuit():

"""

_validate_jax_version()

name = "Default qubit (jax) PennyLane plugin"
short_name = "default.qubit.jax"

Expand Down
44 changes: 44 additions & 0 deletions tests/devices/test_default_qubit_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

jax = pytest.importorskip("jax", minversion="0.2")
jnp = jax.numpy
import jaxlib
import numpy as np
from jax.config import config

Expand All @@ -24,6 +25,49 @@
from pennylane.pulse import ParametrizedHamiltonian


@pytest.mark.jax
@pytest.mark.parametrize(
"version, package, should_raise",
[
("0.4.4", jax, True),
("0.4.3", jax, False),
("0.4.4", jaxlib, True),
("0.4.3", jaxlib, False),
],
)
def test_jax_version(version, package, should_raise, monkeypatch):
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
from pennylane.devices.default_qubit_jax import _validate_jax_version

with monkeypatch.context() as m:
m.setattr(package, "__version__", version)

if should_raise:
msg = "installation is 0.4.4"

with pytest.raises(RuntimeError, match=msg):
_validate_jax_version()

with pytest.raises(RuntimeError, match=msg):
dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, interface="jax")
def circuit():
return None

with pytest.raises(RuntimeError, match=msg):
_ = qml.device("default.qubit.jax", wires=1)
else:
_validate_jax_version()

_ = qml.device("default.qubit.jax", wires=1)

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, interface="jax")
def circuit():
return None


@pytest.mark.jax
def test_analytic_deprecation():
"""Tests if the kwarg `analytic` is used and displays error message."""
Expand Down