diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 3bd951b0e6d..e95daf3f2de 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -2,6 +2,22 @@

New features since last release

+- The JAX interface now supports all devices. + [(#1076)](https://github.com/PennyLaneAI/pennylane/pull/1076) + + Here is an example of how to use JAX with Cirq: + + ```python + dev = qml.device('cirq.simulator', wires=1) + @qml.qnode(dev, interface="jax") + def circuit(x): + qml.RX(x[1], wires=0) + qml.Rot(x[0], x[1], x[2], wires=0) + return qml.expval(qml.PauliZ(0)) + weights = jnp.array([0.2, 0.5, 0.1]) + print(circuit(weights)) # DeviceArray(...) + ``` + - Added the `ControlledPhaseShift` gate as well as the `QFT` operation for applying quantum Fourier transforms. [(#1064)](https://github.com/PennyLaneAI/pennylane/pull/1064) @@ -59,7 +75,7 @@ This release contains contributions from (in alphabetical order): -Thomas Bromley, Josh Izaac, Daniel Polatajko +Thomas Bromley, Josh Izaac, Daniel Polatajko, Chase Roberts # Release 0.14.0 (current release) diff --git a/pennylane/tape/interfaces/jax.py b/pennylane/tape/interfaces/jax.py new file mode 100644 index 00000000000..061f381efc3 --- /dev/null +++ b/pennylane/tape/interfaces/jax.py @@ -0,0 +1,139 @@ +# Copyright 2018-2021 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains the mixin interface class for creating differentiable quantum tapes with +JAX. +""" +from functools import partial +import jax +import jax.experimental.host_callback as host_callback +import jax.numpy as jnp +from pennylane.tape.queuing import AnnotatedQueue +from pennylane.operation import Variance, Expectation + + +class JAXInterface(AnnotatedQueue): + """Mixin class for applying an JAX interface to a :class:`~.JacobianTape`. + + JAX-compatible quantum tape classes can be created via subclassing: + + .. code-block:: python + + class MyJAXQuantumTape(JAXInterface, JacobianTape): + + Alternatively, the JAX interface can be dynamically applied to existing + quantum tapes via the :meth:`~.apply` class method. This modifies the + tape **in place**. + + Once created, the JAX interface can be used to perform quantum-classical + differentiable programming. + + .. note:: + + If using a device that supports native JAX computation and backpropagation, such as + :class:`~.DefaultQubitJAX`, the JAX interface **does not need to be applied**. It + is only applied to tapes executed on non-JAX compatible devices. + + **Example** + + Once a JAX quantum tape has been created, it can be differentiated using JAX: + + .. code-block:: python + + tape = JAXInterface.apply(JacobianTape()) + + with tape: + qml.Rot(0, 0, 0, wires=0) + expval(qml.PauliX(0)) + + def cost_fn(x, y, z, device): + tape.set_parameters([x, y ** 2, y * np.sin(z)], trainable_only=False) + return tape.execute(device=device) + + >>> x = jnp.array(0.1, requires_grad=False) + >>> y = jnp.array(0.2, requires_grad=True) + >>> z = jnp.array(0.3, requires_grad=True) + >>> dev = qml.device("default.qubit", wires=2) + >>> cost_fn(x, y, z, device=dev) + DeviceArray([ 0.03991951], dtype=float32) + >>> jac_fn = jax.vjp(cost_fn) + >>> jac_fn(x, y, z, device=dev) + DeviceArray([[ 0.39828408, -0.00045133]], dtype=float32) + """ + + # pylint: disable=attribute-defined-outside-init + dtype = jnp.float64 + + @property + def interface(self): # pylint: disable=missing-function-docstring + return "jax" + + def _execute(self, params, device): + # TODO (chase): Add support for more than 1 measured observable. + if len(self.observables) != 1: + raise ValueError( + "The JAX interface currently only supports quantum nodes with a single return type." + ) + return_type = self.observables[0].return_type + if return_type is not Variance and return_type is not Expectation: + raise ValueError( + f"Only Variance and Expectation returns are support for the JAX interface, given {return_type}." + ) + + @jax.custom_vjp + def wrapped_exec(params): + exec_fn = partial(self.execute_device, device=device) + return host_callback.call( + exec_fn, params, result_shape=jax.ShapeDtypeStruct((1,), JAXInterface.dtype) + ) + + def wrapped_exec_fwd(params): + return wrapped_exec(params), params + + def wrapped_exec_bwd(params, g): + def jacobian(params): + tape = self.copy() + tape.set_parameters(params) + return tape.jacobian(device, params=params, **tape.jacobian_options) + + val = g.reshape((-1,)) * host_callback.call( + jacobian, + params, + result_shape=jax.ShapeDtypeStruct((1, len(params)), JAXInterface.dtype), + ) + return (list(val.reshape((-1,))),) # Comma is on purpose. + + wrapped_exec.defvjp(wrapped_exec_fwd, wrapped_exec_bwd) + return wrapped_exec(params) + + @classmethod + def apply(cls, tape): + """Apply the JAX interface to an existing tape in-place. + + Args: + tape (.JacobianTape): a quantum tape to apply the JAX interface to + + **Example** + + >>> with JacobianTape() as tape: + ... qml.RX(0.5, wires=0) + ... expval(qml.PauliZ(0)) + >>> JAXInterface.apply(tape) + >>> tape + , params=1> + """ + tape_class = getattr(tape, "__bare__", tape.__class__) + tape.__bare__ = tape_class + tape.__class__ = type("JAXQuantumTape", (cls, tape_class), {}) + return tape diff --git a/pennylane/tape/qnode.py b/pennylane/tape/qnode.py index 640887aef99..1d231af6583 100644 --- a/pennylane/tape/qnode.py +++ b/pennylane/tape/qnode.py @@ -766,13 +766,38 @@ def to_autograd(self): AutogradInterface.apply(self.qtape) def to_jax(self): - """Validation checks when a user expects to use the JAX interface.""" - if self.diff_method != "backprop": + """Apply the JAX interface to the internal quantum tape. + + Args: + dtype (tf.dtype): The dtype that the JAX QNode should + output. If not provided, the default is ``jnp.float64``. + + Raises: + .QuantumFunctionError: if TensorFlow >= 2.1 is not installed + """ + # pylint: disable=import-outside-toplevel + try: + from pennylane.tape.interfaces.jax import JAXInterface + + if self.interface != "jax" and self.interface is not None: + # Since the interface is changing, need to re-validate the tape class. + self._tape, interface, self.device, diff_options = self.get_tape( + self._original_device, "jax", self.diff_method + ) + + self.interface = interface + self.diff_options.update(diff_options) + else: + self.interface = "jax" + + if self.qtape is not None: + JAXInterface.apply(self.qtape) + + except ImportError as e: raise qml.QuantumFunctionError( - "The JAX interface can only be used with " - "diff_method='backprop' on supported devices" - ) - self.interface = "jax" + "JAX not found. Please install the latest " + "version of JAX to enable the 'jax' interface." + ) from e INTERFACE_MAP = {"autograd": to_autograd, "torch": to_torch, "tf": to_tf, "jax": to_jax} diff --git a/tests/devices/test_default_qubit_jax.py b/tests/devices/test_default_qubit_jax.py index bbc09b3f9c7..04b0ffe67b6 100644 --- a/tests/devices/test_default_qubit_jax.py +++ b/tests/devices/test_default_qubit_jax.py @@ -436,26 +436,6 @@ def cost(weights): grad = jax.grad(cost)(weights) assert grad.shape == weights.shape - def test_non_backprop_error(self): - """Test that an error is raised in tape mode if the diff method is not backprop""" - if not qml.tape_mode_active(): - pytest.skip("Test only applies in tape mode") - - dev = qml.device("default.qubit.jax", wires=2) - - def circuit(weights): - qml.RX(weights[0], wires=0) - qml.RY(weights[1], wires=1) - qml.CNOT(wires=[0, 1]) - return qml.expval(qml.PauliZ(0)) - - qnode = qml.QNode(circuit, dev, interface="jax", diff_method="parameter-shift") - weights = jnp.array([0.1, 0.2]) - - with pytest.raises(qml.QuantumFunctionError, match="The JAX interface can only be used with"): - qnode(weights) - - class TestOps: """Unit tests for operations supported by the default.qubit.jax device""" diff --git a/tests/tape/interfaces/test_qnode_jax.py b/tests/tape/interfaces/test_qnode_jax.py new file mode 100644 index 00000000000..79069073bbd --- /dev/null +++ b/tests/tape/interfaces/test_qnode_jax.py @@ -0,0 +1,203 @@ +# Copyright 2018-2021 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the JAX interface""" +import pytest +jax = pytest.importorskip("jax") +jnp = pytest.importorskip("jax.numpy") +import numpy as np +import pennylane as qml +from pennylane.tape import JacobianTape, qnode, QNode, QubitParamShiftTape + +def test_qnode_intergration(): + """Test a simple use of qnode with a JAX interface and non-JAX device""" + dev = qml.device("default.mixed", wires=2) # A non-JAX device + + @qml.qnode(dev, interface="jax") + def circuit(weights): + qml.RX(weights[0], wires=0) + qml.RZ(weights[1], wires=1) + return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) + + weights = jnp.array([0.1, 0.2]) + val = circuit(weights) + assert "DeviceArray" in val.__repr__() + +def test_to_jax(): + """Test the to_jax method""" + dev = qml.device("default.mixed", wires=2) + + @qml.qnode(dev, interface="autograd") + def circuit(weights): + qml.RX(weights[0], wires=0) + qml.RZ(weights[1], wires=1) + return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) + + circuit.to_jax() + weights = jnp.array([0.1, 0.2]) + val = circuit(weights) + assert "DeviceArray" in val.__repr__() + + +def test_simple_jacobian(): + """Test the use of jax.jaxrev""" + dev = qml.device("default.mixed", wires=2) # A non-JAX device. + + @qml.qnode(dev, interface="jax", diff_method="parameter-shift") + def circuit(weights): + qml.RX(weights[0], wires=0) + qml.RY(weights[1], wires=1) + return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) + + weights = jnp.array([0.1, 0.2]) + grads = jax.jacrev(circuit)(weights) + # This is the easiest way to ensure our object is a DeviceArray instead + # of a numpy array. + assert "DeviceArray" in grads.__repr__() + assert grads.shape == (2,) + np.testing.assert_allclose(grads, np.array([-0.09784342, -0.19767685])) + +def test_simple_grad(): + """Test the use of jax.grad""" + dev = qml.device("default.mixed", wires=2) # A non-JAX device. + @qml.qnode(dev, interface="jax", diff_method="parameter-shift") + def circuit(weights): + qml.RX(weights[0], wires=0) + qml.RZ(weights[1], wires=1) + return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) + + weights = jnp.array([0.1, 0.2]) + val = jax.grad(circuit)(weights) + assert "DeviceArray" in val.__repr__() + +@pytest.mark.parametrize("diff_method", ['parameter-shift', 'finite-diff']) +def test_differentiable_expand(diff_method): + """Test that operation and nested tapes expansion + is differentiable""" + class U3(qml.U3): + def expand(self): + theta, phi, lam = self.data + wires = self.wires + + with JacobianTape() as tape: + qml.Rot(lam, theta, -lam, wires=wires) + qml.PhaseShift(phi + lam, wires=wires) + + return tape + + dev = qml.device("default.mixed", wires=1) + a = jnp.array(0.1) + p = jnp.array([0.1, 0.2, 0.3]) + + @qnode(dev, diff_method=diff_method, interface="jax") + def circuit(a, p): + qml.RX(a, wires=0) + U3(p[0], p[1], p[2], wires=0) + return qml.expval(qml.PauliX(0)) + + res = circuit(a, p) + + expected = np.cos(a) * np.cos(p[1]) * np.sin(p[0]) + np.sin(a) * ( + np.cos(p[2]) * np.sin(p[1]) + np.cos(p[0]) * np.cos(p[1]) * np.sin(p[2]) + ) + tol = 1e-5 + assert np.allclose(res, expected, atol=tol, rtol=0) + + res = jax.grad(circuit, argnums=1)(a, p) + expected = np.array( + [ + np.cos(p[1]) * (np.cos(a) * np.cos(p[0]) - np.sin(a) * np.sin(p[0]) * np.sin(p[2])), + np.cos(p[1]) * np.cos(p[2]) * np.sin(a) + - np.sin(p[1]) + * (np.cos(a) * np.sin(p[0]) + np.cos(p[0]) * np.sin(a) * np.sin(p[2])), + np.sin(a) + * (np.cos(p[0]) * np.cos(p[1]) * np.cos(p[2]) - np.sin(p[1]) * np.sin(p[2])), + ] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + +def qtransform(qnode, a, framework=jnp): + """Transforms every RY(y) gate in a circuit to RX(-a*cos(y))""" + + def construct(self, args, kwargs): + """New quantum tape construct method, that performs + the transform on the tape in a define-by-run manner""" + + t_op = [] + + QNode.construct(self, args, kwargs) + + new_ops = [] + for o in self.qtape.operations: + # here, we loop through all tape operations, and make + # the transformation if a RY gate is encountered. + if isinstance(o, qml.RY): + t_op.append(qml.RX(-a * framework.cos(o.data[0]), wires=o.wires)) + new_ops.append(t_op[-1]) + else: + new_ops.append(o) + + self.qtape._ops = new_ops + self.qtape._update() + + import copy + + new_qnode = copy.deepcopy(qnode) + new_qnode.construct = construct.__get__(new_qnode, QNode) + return new_qnode + + +@pytest.mark.parametrize( + "dev_name,diff_method", + [("default.mixed", "finite-diff"), ("default.qubit.autograd", "parameter-shift")], +) +def test_transform(dev_name, diff_method, monkeypatch, tol): + """Test an example transform""" + monkeypatch.setattr(qml.operation.Operation, "do_check_domain", False) + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, interface="jax", diff_method=diff_method) + def circuit(weights): + op1 = qml.RY(weights[0], wires=0) + op2 = qml.RX(weights[1], wires=0) + return qml.expval(qml.PauliZ(wires=0)) + + weights = np.array([0.32, 0.543]) + a = np.array(0.5) + + def loss(weights, a): + # transform the circuit QNode with trainable weight 'a' + new_circuit = qtransform(circuit, a) + + # evaluate the transformed QNode + res = new_circuit(weights) + + # evaluate the original QNode with pre-processed parameters + res2 = circuit(jnp.sin(weights)) + + # return the sum of the two QNode evaluations + return res + res2 + + res = loss(weights, a) + + grad = jax.grad(loss, argnums=[0, 1])(weights, a) + assert len(grad) == 2 + assert grad[0].shape == weights.shape + assert grad[1].shape == a.shape + + # compare against the expected values + tol = 1e-5 + assert np.allclose(res, 1.8244501889992706, atol=tol, rtol=0) + assert np.allclose(grad[0], [-0.26610258, -0.47053553], atol=tol, rtol=0) + assert np.allclose(grad[1], 0.06486032, atol=tol, rtol=0) diff --git a/tests/tape/interfaces/test_tape_jax.py b/tests/tape/interfaces/test_tape_jax.py new file mode 100644 index 00000000000..234a0f3c5e8 --- /dev/null +++ b/tests/tape/interfaces/test_tape_jax.py @@ -0,0 +1,124 @@ +# Copyright 2018-2021 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the JAX interface""" +import pytest +jax = pytest.importorskip("jax") +jnp = pytest.importorskip("jax.numpy") +import numpy as np +from functools import partial +import pennylane as qml +from pennylane.tape import JacobianTape +from pennylane.tape.interfaces.jax import JAXInterface + + +class TestJAXQuantumTape: + """Test the JAX interface applied to a tape""" + + def test_interface_str(self): + """Test that the interface string is correctly identified as JAX""" + with JAXInterface.apply(JacobianTape()) as tape: + qml.RX(0.5, wires=0) + qml.expval(qml.PauliX(0)) + + assert tape.interface == "jax" + assert isinstance(tape, JAXInterface) + + def test_get_parameters(self): + """Test that the get_parameters function correctly gets the trainable parameters and all + parameters, depending on the trainable_only argument""" + a = jnp.array(0.1) + b = jnp.array(0.2) + c = jnp.array(0.3) + d = jnp.array(0.4) + + with JAXInterface.apply(JacobianTape()) as tape: + qml.Rot(a, b, c, wires=0) + qml.RX(d, wires=1) + qml.CNOT(wires=[0, 1]) + qml.expval(qml.PauliX(0)) + + np.testing.assert_array_equal(tape.get_parameters(), [a, b, c, d]) + + def test_execution(self): + """Test execution""" + a = jnp.array(0.1) + b = jnp.array(0.2) + + def cost(a, b, device): + with JAXInterface.apply(JacobianTape()) as tape: + qml.RY(a, wires=0) + qml.RX(b, wires=0) + qml.expval(qml.PauliZ(0)) + return tape.execute(device) + + dev = qml.device("default.qubit", wires=1) + res = cost(a, b, device=dev) + assert res.shape == (1,) + # Easiest way to test object is a device array instead of np.array + assert "DeviceArray" in res.__repr__() + + + def test_state_raises(self): + """Test returning state raises exception""" + a = jnp.array(0.1) + b = jnp.array(0.2) + + def cost(a, b, device): + with JAXInterface.apply(JacobianTape()) as tape: + qml.RY(a, wires=0) + qml.RX(b, wires=0) + qml.state() + return tape.execute(device) + + dev = qml.device("default.qubit", wires=1) + # TODO(chase): Make this actually work and not raise an error. + with pytest.raises(ValueError): + res = cost(a, b, device=dev) + + def test_execution_with_jit(self): + """Test execution""" + a = jnp.array(0.1) + b = jnp.array(0.2) + + def cost(a, b, device): + with JAXInterface.apply(JacobianTape()) as tape: + qml.RY(a, wires=0) + qml.RX(b, wires=0) + qml.expval(qml.PauliZ(0)) + return tape.execute(device) + + # Not a JAX device! + dev = qml.device("default.qubit", wires=1) + dev_cost = partial(cost, device=dev) + res = jax.jit(dev_cost)(a, b) + assert res.shape == (1,) + # Easiest way to test object is a device array instead of np.array + assert "DeviceArray" in res.__repr__() + + def test_qnode_interface(self): + + dev = qml.device("default.mixed", wires=1) + + @qml.qnode(dev, interface="jax") + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=0) + return qml.expval(qml.PauliZ(0)) + + a = jnp.array(0.1) + b = jnp.array(0.2) + + res = circuit(a, b) + assert "DeviceArray" in res.__repr__() +