From c638479d040107242bb2374608bdcbed359c82f1 Mon Sep 17 00:00:00 2001 From: lillian542 Date: Thu, 23 Feb 2023 14:10:44 -0500 Subject: [PATCH 01/10] Raise error for 0.4.4 jaxlib and jax --- pennylane/devices/default_qubit_jax.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/pennylane/devices/default_qubit_jax.py b/pennylane/devices/default_qubit_jax.py index 3c19213851a..4fd7090b21f 100644 --- a/pennylane/devices/default_qubit_jax.py +++ b/pennylane/devices/default_qubit_jax.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""This module contains an jax implementation of the :class:`~.DefaultQubit` +"""This module contains a jax implementation of the :class:`~.DefaultQubit` reference plugin. """ # pylint: disable=ungrouped-imports @@ -30,11 +30,25 @@ from pennylane.pulse.parametrized_hamiltonian_pytree import ParametrizedHamiltonianPytree - except ImportError as e: # pragma: no cover raise ImportError("default.qubit.jax device requires installing jax>0.3.20") from e +def _validate_jax_version(): + import jaxlib # pylint:disable=import-outside-toplevel + + if jax.__version__ == "0.4.4": + raise RuntimeError( + "The current JAX installation is 0.4.4. The JAX implementation for default.qubit requires " + "version 0.4.3 or lower for JAX." + ) + if jaxlib.__version__ == "0.4.4": + raise RuntimeError( + "The current jaxlib installation is 0.4.4. The JAX implementation for default.qubit " + "requires version 0.4.3 or lower for jaxlib." + ) + + class DefaultQubitJax(DefaultQubit): """Simulator plugin based on ``"default.qubit"``, written using jax. @@ -119,7 +133,7 @@ def circuit(): a = keyed_circuit(key1) b = keyed_circuit(key2) # b will be different samples now. - Check out out the `JAX random documentation `__ + Check out the `JAX random documentation `__ for more information. Args: @@ -136,6 +150,8 @@ def circuit(): """ + _validate_jax_version() + name = "Default qubit (jax) PennyLane plugin" short_name = "default.qubit.jax" From b3caca178a869ebe18b86f0e42b5da872e4864bd Mon Sep 17 00:00:00 2001 From: lillian542 Date: Thu, 23 Feb 2023 14:11:57 -0500 Subject: [PATCH 02/10] add tests --- tests/devices/test_default_qubit_jax.py | 44 +++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/devices/test_default_qubit_jax.py b/tests/devices/test_default_qubit_jax.py index 0a2dcab28d8..24cc87ab331 100644 --- a/tests/devices/test_default_qubit_jax.py +++ b/tests/devices/test_default_qubit_jax.py @@ -15,6 +15,7 @@ jax = pytest.importorskip("jax", minversion="0.2") jnp = jax.numpy +import jaxlib import numpy as np from jax.config import config @@ -24,6 +25,49 @@ from pennylane.pulse import ParametrizedHamiltonian +@pytest.mark.parametrize( + "version, package, should_raise", + [ + ("0.4.4", jax, True), + ("0.4.3", jax, False), + ("0.4.4", jaxlib, True), + ("0.4.3", jaxlib, False), + ], +) +def test_jax_version(version, package, should_raise, monkeypatch): + + from pennylane.devices.default_qubit_jax import _validate_jax_version + + with monkeypatch.context() as m: + m.setattr(package, "__version__", version) + + if should_raise: + msg = "installation is 0.4.4" + + with pytest.raises(RuntimeError, match=msg): + _validate_jax_version() + + with pytest.raises(RuntimeError, match=msg): + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev, interface="jax") + def circuit(): + return None + + with pytest.raises(RuntimeError, match=msg): + _ = qml.device("default.qubit.jax", wires=1) + else: + _validate_jax_version() + + _ = qml.device("default.qubit.jax", wires=1) + + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev, interface="jax") + def circuit(): + return None + + @pytest.mark.jax def test_analytic_deprecation(): """Tests if the kwarg `analytic` is used and displays error message.""" From acdc893e8937a4bb70877cb48091f512758aae05 Mon Sep 17 00:00:00 2001 From: lillian542 Date: Thu, 23 Feb 2023 14:51:00 -0500 Subject: [PATCH 03/10] black for tests --- tests/devices/test_default_qubit_jax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/devices/test_default_qubit_jax.py b/tests/devices/test_default_qubit_jax.py index 24cc87ab331..bd06fd00e7f 100644 --- a/tests/devices/test_default_qubit_jax.py +++ b/tests/devices/test_default_qubit_jax.py @@ -35,7 +35,6 @@ ], ) def test_jax_version(version, package, should_raise, monkeypatch): - from pennylane.devices.default_qubit_jax import _validate_jax_version with monkeypatch.context() as m: From f8858b3e7940a123b8cefa795a5600e17ee80eb7 Mon Sep 17 00:00:00 2001 From: lillian542 Date: Thu, 23 Feb 2023 16:35:17 -0500 Subject: [PATCH 04/10] add jax mark for test --- tests/devices/test_default_qubit_jax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/devices/test_default_qubit_jax.py b/tests/devices/test_default_qubit_jax.py index bd06fd00e7f..459646e5a73 100644 --- a/tests/devices/test_default_qubit_jax.py +++ b/tests/devices/test_default_qubit_jax.py @@ -25,6 +25,7 @@ from pennylane.pulse import ParametrizedHamiltonian +@pytest.mark.jax @pytest.mark.parametrize( "version, package, should_raise", [ From a01277312b52c8728100f1e6247c236893bed933 Mon Sep 17 00:00:00 2001 From: lillian542 <38584660+lillian542@users.noreply.github.com> Date: Fri, 24 Feb 2023 10:07:28 -0500 Subject: [PATCH 05/10] Remove jaxlib check Co-authored-by: Filippo Vicentini --- pennylane/devices/default_qubit_jax.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pennylane/devices/default_qubit_jax.py b/pennylane/devices/default_qubit_jax.py index 4fd7090b21f..c900619432b 100644 --- a/pennylane/devices/default_qubit_jax.py +++ b/pennylane/devices/default_qubit_jax.py @@ -35,18 +35,12 @@ def _validate_jax_version(): - import jaxlib # pylint:disable=import-outside-toplevel if jax.__version__ == "0.4.4": raise RuntimeError( "The current JAX installation is 0.4.4. The JAX implementation for default.qubit requires " "version 0.4.3 or lower for JAX." ) - if jaxlib.__version__ == "0.4.4": - raise RuntimeError( - "The current jaxlib installation is 0.4.4. The JAX implementation for default.qubit " - "requires version 0.4.3 or lower for jaxlib." - ) class DefaultQubitJax(DefaultQubit): From 2e76d75042cbc25348f9f911371ea29465ef1f78 Mon Sep 17 00:00:00 2001 From: lillian542 <38584660+lillian542@users.noreply.github.com> Date: Fri, 24 Feb 2023 10:07:50 -0500 Subject: [PATCH 06/10] More detailed error msg Co-authored-by: Filippo Vicentini --- pennylane/devices/default_qubit_jax.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pennylane/devices/default_qubit_jax.py b/pennylane/devices/default_qubit_jax.py index c900619432b..6f7b56f69bb 100644 --- a/pennylane/devices/default_qubit_jax.py +++ b/pennylane/devices/default_qubit_jax.py @@ -38,8 +38,15 @@ def _validate_jax_version(): if jax.__version__ == "0.4.4": raise RuntimeError( - "The current JAX installation is 0.4.4. The JAX implementation for default.qubit requires " - "version 0.4.3 or lower for JAX." + "\nYour installed version of JAX is 0.4.4 but Pennylane is incompatible with it.\n\n" + "You can either downgrade JAX to version 0.4.3 or update to a more recent version if available." + "If you downgrade, you will also need to downgrade JAXLIB to version 0.4.3 or earlier.\n" + "If you are using pip to manage your packages, you can run the following command:\n\n" + "\tpip install 'jax==0.4.3' 'jaxlib==0.4.3'\n\n" + "If you are using conda to manage your packages, you can run the following command:\n\n" + "\tconda install 'jax==0.4.3' 'jaxlib==0.4.3'\n\n" + "If you still have problems, please open an issue at the following link:\n\n" + "\thttps://github.com/PennyLaneAI/pennylane/issues\n" ) From abd49b38c9aa2ed0e82a38c1d34641c63e8152dc Mon Sep 17 00:00:00 2001 From: lillian542 Date: Fri, 24 Feb 2023 11:02:24 -0500 Subject: [PATCH 07/10] Update tests --- tests/devices/test_default_qubit_jax.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/devices/test_default_qubit_jax.py b/tests/devices/test_default_qubit_jax.py index 459646e5a73..752a4eccf87 100644 --- a/tests/devices/test_default_qubit_jax.py +++ b/tests/devices/test_default_qubit_jax.py @@ -31,8 +31,6 @@ [ ("0.4.4", jax, True), ("0.4.3", jax, False), - ("0.4.4", jaxlib, True), - ("0.4.3", jaxlib, False), ], ) def test_jax_version(version, package, should_raise, monkeypatch): @@ -42,7 +40,7 @@ def test_jax_version(version, package, should_raise, monkeypatch): m.setattr(package, "__version__", version) if should_raise: - msg = "installation is 0.4.4" + msg = "version of JAX is 0.4.4" with pytest.raises(RuntimeError, match=msg): _validate_jax_version() From 8f0625c94efb438e88029dcc1cb1f4732fcc4244 Mon Sep 17 00:00:00 2001 From: rmoyard Date: Fri, 24 Feb 2023 14:00:32 -0500 Subject: [PATCH 08/10] Tests and black --- pennylane/devices/default_qubit_jax.py | 1 - tests/devices/test_default_qubit_jax.py | 10 +--------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/pennylane/devices/default_qubit_jax.py b/pennylane/devices/default_qubit_jax.py index 6f7b56f69bb..c9a7c9f7a01 100644 --- a/pennylane/devices/default_qubit_jax.py +++ b/pennylane/devices/default_qubit_jax.py @@ -35,7 +35,6 @@ def _validate_jax_version(): - if jax.__version__ == "0.4.4": raise RuntimeError( "\nYour installed version of JAX is 0.4.4 but Pennylane is incompatible with it.\n\n" diff --git a/tests/devices/test_default_qubit_jax.py b/tests/devices/test_default_qubit_jax.py index 752a4eccf87..4f8196ba971 100644 --- a/tests/devices/test_default_qubit_jax.py +++ b/tests/devices/test_default_qubit_jax.py @@ -38,6 +38,7 @@ def test_jax_version(version, package, should_raise, monkeypatch): with monkeypatch.context() as m: m.setattr(package, "__version__", version) + print(jax.__version__) if should_raise: msg = "version of JAX is 0.4.4" @@ -45,15 +46,6 @@ def test_jax_version(version, package, should_raise, monkeypatch): with pytest.raises(RuntimeError, match=msg): _validate_jax_version() - with pytest.raises(RuntimeError, match=msg): - dev = qml.device("default.qubit", wires=1) - - @qml.qnode(dev, interface="jax") - def circuit(): - return None - - with pytest.raises(RuntimeError, match=msg): - _ = qml.device("default.qubit.jax", wires=1) else: _validate_jax_version() From 9a63b6cf89f91bcfea72224d29766db7c0ae61b3 Mon Sep 17 00:00:00 2001 From: rmoyard Date: Fri, 24 Feb 2023 14:02:23 -0500 Subject: [PATCH 09/10] Typo --- tests/devices/test_default_qubit_jax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/devices/test_default_qubit_jax.py b/tests/devices/test_default_qubit_jax.py index 4f8196ba971..8a368b0a343 100644 --- a/tests/devices/test_default_qubit_jax.py +++ b/tests/devices/test_default_qubit_jax.py @@ -38,7 +38,6 @@ def test_jax_version(version, package, should_raise, monkeypatch): with monkeypatch.context() as m: m.setattr(package, "__version__", version) - print(jax.__version__) if should_raise: msg = "version of JAX is 0.4.4" From afb384400f81dd8afe8f3c5e9bd9ac2cd4573a8d Mon Sep 17 00:00:00 2001 From: rmoyard Date: Fri, 24 Feb 2023 14:08:10 -0500 Subject: [PATCH 10/10] Change --- pennylane/devices/default_qubit_jax.py | 2 ++ tests/devices/test_default_qubit_jax.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/pennylane/devices/default_qubit_jax.py b/pennylane/devices/default_qubit_jax.py index c9a7c9f7a01..20df2aaaa22 100644 --- a/pennylane/devices/default_qubit_jax.py +++ b/pennylane/devices/default_qubit_jax.py @@ -179,6 +179,8 @@ def circuit(): _ndim = staticmethod(jnp.ndim) def __init__(self, wires, *, shots=None, prng_key=None, analytic=None): + _validate_jax_version() + if jax_config.read("jax_enable_x64"): c_dtype = jnp.complex128 r_dtype = jnp.float64 diff --git a/tests/devices/test_default_qubit_jax.py b/tests/devices/test_default_qubit_jax.py index 8a368b0a343..986cf437882 100644 --- a/tests/devices/test_default_qubit_jax.py +++ b/tests/devices/test_default_qubit_jax.py @@ -45,6 +45,17 @@ def test_jax_version(version, package, should_raise, monkeypatch): with pytest.raises(RuntimeError, match=msg): _validate_jax_version() + dev = qml.device("default.qubit", wires=1) + + with pytest.raises(RuntimeError, match=msg): + + @qml.qnode(dev, interface="jax", diff_method="backprop") + def circuit(): + return None + + with pytest.raises(RuntimeError, match=msg): + dev = qml.device("default.qubit.jax", wires=1) + else: _validate_jax_version()