diff --git a/pennylane/devices/default_qubit_jax.py b/pennylane/devices/default_qubit_jax.py index 3c19213851a..20df2aaaa22 100644 --- a/pennylane/devices/default_qubit_jax.py +++ b/pennylane/devices/default_qubit_jax.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""This module contains an jax implementation of the :class:`~.DefaultQubit` +"""This module contains a jax implementation of the :class:`~.DefaultQubit` reference plugin. """ # pylint: disable=ungrouped-imports @@ -30,11 +30,25 @@ from pennylane.pulse.parametrized_hamiltonian_pytree import ParametrizedHamiltonianPytree - except ImportError as e: # pragma: no cover raise ImportError("default.qubit.jax device requires installing jax>0.3.20") from e +def _validate_jax_version(): + if jax.__version__ == "0.4.4": + raise RuntimeError( + "\nYour installed version of JAX is 0.4.4 but Pennylane is incompatible with it.\n\n" + "You can either downgrade JAX to version 0.4.3 or update to a more recent version if available." + "If you downgrade, you will also need to downgrade JAXLIB to version 0.4.3 or earlier.\n" + "If you are using pip to manage your packages, you can run the following command:\n\n" + "\tpip install 'jax==0.4.3' 'jaxlib==0.4.3'\n\n" + "If you are using conda to manage your packages, you can run the following command:\n\n" + "\tconda install 'jax==0.4.3' 'jaxlib==0.4.3'\n\n" + "If you still have problems, please open an issue at the following link:\n\n" + "\thttps://github.com/PennyLaneAI/pennylane/issues\n" + ) + + class DefaultQubitJax(DefaultQubit): """Simulator plugin based on ``"default.qubit"``, written using jax. @@ -119,7 +133,7 @@ def circuit(): a = keyed_circuit(key1) b = keyed_circuit(key2) # b will be different samples now. - Check out out the `JAX random documentation `__ + Check out the `JAX random documentation `__ for more information. Args: @@ -136,6 +150,8 @@ def circuit(): """ + _validate_jax_version() + name = "Default qubit (jax) PennyLane plugin" short_name = "default.qubit.jax" @@ -163,6 +179,8 @@ def circuit(): _ndim = staticmethod(jnp.ndim) def __init__(self, wires, *, shots=None, prng_key=None, analytic=None): + _validate_jax_version() + if jax_config.read("jax_enable_x64"): c_dtype = jnp.complex128 r_dtype = jnp.float64 diff --git a/tests/devices/test_default_qubit_jax.py b/tests/devices/test_default_qubit_jax.py index 0a2dcab28d8..986cf437882 100644 --- a/tests/devices/test_default_qubit_jax.py +++ b/tests/devices/test_default_qubit_jax.py @@ -15,6 +15,7 @@ jax = pytest.importorskip("jax", minversion="0.2") jnp = jax.numpy +import jaxlib import numpy as np from jax.config import config @@ -24,6 +25,49 @@ from pennylane.pulse import ParametrizedHamiltonian +@pytest.mark.jax +@pytest.mark.parametrize( + "version, package, should_raise", + [ + ("0.4.4", jax, True), + ("0.4.3", jax, False), + ], +) +def test_jax_version(version, package, should_raise, monkeypatch): + from pennylane.devices.default_qubit_jax import _validate_jax_version + + with monkeypatch.context() as m: + m.setattr(package, "__version__", version) + + if should_raise: + msg = "version of JAX is 0.4.4" + + with pytest.raises(RuntimeError, match=msg): + _validate_jax_version() + + dev = qml.device("default.qubit", wires=1) + + with pytest.raises(RuntimeError, match=msg): + + @qml.qnode(dev, interface="jax", diff_method="backprop") + def circuit(): + return None + + with pytest.raises(RuntimeError, match=msg): + dev = qml.device("default.qubit.jax", wires=1) + + else: + _validate_jax_version() + + _ = qml.device("default.qubit.jax", wires=1) + + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev, interface="jax") + def circuit(): + return None + + @pytest.mark.jax def test_analytic_deprecation(): """Tests if the kwarg `analytic` is used and displays error message."""