diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md
index 3bd951b0e6d..e95daf3f2de 100644
--- a/.github/CHANGELOG.md
+++ b/.github/CHANGELOG.md
@@ -2,6 +2,22 @@
New features since last release
+- The JAX interface now supports all devices.
+ [(#1076)](https://github.com/PennyLaneAI/pennylane/pull/1076)
+
+ Here is an example of how to use JAX with Cirq:
+
+ ```python
+ dev = qml.device('cirq.simulator', wires=1)
+ @qml.qnode(dev, interface="jax")
+ def circuit(x):
+ qml.RX(x[1], wires=0)
+ qml.Rot(x[0], x[1], x[2], wires=0)
+ return qml.expval(qml.PauliZ(0))
+ weights = jnp.array([0.2, 0.5, 0.1])
+ print(circuit(weights)) # DeviceArray(...)
+ ```
+
- Added the `ControlledPhaseShift` gate as well as the `QFT` operation for applying quantum Fourier
transforms.
[(#1064)](https://github.com/PennyLaneAI/pennylane/pull/1064)
@@ -59,7 +75,7 @@
This release contains contributions from (in alphabetical order):
-Thomas Bromley, Josh Izaac, Daniel Polatajko
+Thomas Bromley, Josh Izaac, Daniel Polatajko, Chase Roberts
# Release 0.14.0 (current release)
diff --git a/pennylane/tape/interfaces/jax.py b/pennylane/tape/interfaces/jax.py
new file mode 100644
index 00000000000..061f381efc3
--- /dev/null
+++ b/pennylane/tape/interfaces/jax.py
@@ -0,0 +1,139 @@
+# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module contains the mixin interface class for creating differentiable quantum tapes with
+JAX.
+"""
+from functools import partial
+import jax
+import jax.experimental.host_callback as host_callback
+import jax.numpy as jnp
+from pennylane.tape.queuing import AnnotatedQueue
+from pennylane.operation import Variance, Expectation
+
+
+class JAXInterface(AnnotatedQueue):
+ """Mixin class for applying an JAX interface to a :class:`~.JacobianTape`.
+
+ JAX-compatible quantum tape classes can be created via subclassing:
+
+ .. code-block:: python
+
+ class MyJAXQuantumTape(JAXInterface, JacobianTape):
+
+ Alternatively, the JAX interface can be dynamically applied to existing
+ quantum tapes via the :meth:`~.apply` class method. This modifies the
+ tape **in place**.
+
+ Once created, the JAX interface can be used to perform quantum-classical
+ differentiable programming.
+
+ .. note::
+
+ If using a device that supports native JAX computation and backpropagation, such as
+ :class:`~.DefaultQubitJAX`, the JAX interface **does not need to be applied**. It
+ is only applied to tapes executed on non-JAX compatible devices.
+
+ **Example**
+
+ Once a JAX quantum tape has been created, it can be differentiated using JAX:
+
+ .. code-block:: python
+
+ tape = JAXInterface.apply(JacobianTape())
+
+ with tape:
+ qml.Rot(0, 0, 0, wires=0)
+ expval(qml.PauliX(0))
+
+ def cost_fn(x, y, z, device):
+ tape.set_parameters([x, y ** 2, y * np.sin(z)], trainable_only=False)
+ return tape.execute(device=device)
+
+ >>> x = jnp.array(0.1, requires_grad=False)
+ >>> y = jnp.array(0.2, requires_grad=True)
+ >>> z = jnp.array(0.3, requires_grad=True)
+ >>> dev = qml.device("default.qubit", wires=2)
+ >>> cost_fn(x, y, z, device=dev)
+ DeviceArray([ 0.03991951], dtype=float32)
+ >>> jac_fn = jax.vjp(cost_fn)
+ >>> jac_fn(x, y, z, device=dev)
+ DeviceArray([[ 0.39828408, -0.00045133]], dtype=float32)
+ """
+
+ # pylint: disable=attribute-defined-outside-init
+ dtype = jnp.float64
+
+ @property
+ def interface(self): # pylint: disable=missing-function-docstring
+ return "jax"
+
+ def _execute(self, params, device):
+ # TODO (chase): Add support for more than 1 measured observable.
+ if len(self.observables) != 1:
+ raise ValueError(
+ "The JAX interface currently only supports quantum nodes with a single return type."
+ )
+ return_type = self.observables[0].return_type
+ if return_type is not Variance and return_type is not Expectation:
+ raise ValueError(
+ f"Only Variance and Expectation returns are support for the JAX interface, given {return_type}."
+ )
+
+ @jax.custom_vjp
+ def wrapped_exec(params):
+ exec_fn = partial(self.execute_device, device=device)
+ return host_callback.call(
+ exec_fn, params, result_shape=jax.ShapeDtypeStruct((1,), JAXInterface.dtype)
+ )
+
+ def wrapped_exec_fwd(params):
+ return wrapped_exec(params), params
+
+ def wrapped_exec_bwd(params, g):
+ def jacobian(params):
+ tape = self.copy()
+ tape.set_parameters(params)
+ return tape.jacobian(device, params=params, **tape.jacobian_options)
+
+ val = g.reshape((-1,)) * host_callback.call(
+ jacobian,
+ params,
+ result_shape=jax.ShapeDtypeStruct((1, len(params)), JAXInterface.dtype),
+ )
+ return (list(val.reshape((-1,))),) # Comma is on purpose.
+
+ wrapped_exec.defvjp(wrapped_exec_fwd, wrapped_exec_bwd)
+ return wrapped_exec(params)
+
+ @classmethod
+ def apply(cls, tape):
+ """Apply the JAX interface to an existing tape in-place.
+
+ Args:
+ tape (.JacobianTape): a quantum tape to apply the JAX interface to
+
+ **Example**
+
+ >>> with JacobianTape() as tape:
+ ... qml.RX(0.5, wires=0)
+ ... expval(qml.PauliZ(0))
+ >>> JAXInterface.apply(tape)
+ >>> tape
+ , params=1>
+ """
+ tape_class = getattr(tape, "__bare__", tape.__class__)
+ tape.__bare__ = tape_class
+ tape.__class__ = type("JAXQuantumTape", (cls, tape_class), {})
+ return tape
diff --git a/pennylane/tape/qnode.py b/pennylane/tape/qnode.py
index 640887aef99..1d231af6583 100644
--- a/pennylane/tape/qnode.py
+++ b/pennylane/tape/qnode.py
@@ -766,13 +766,38 @@ def to_autograd(self):
AutogradInterface.apply(self.qtape)
def to_jax(self):
- """Validation checks when a user expects to use the JAX interface."""
- if self.diff_method != "backprop":
+ """Apply the JAX interface to the internal quantum tape.
+
+ Args:
+ dtype (tf.dtype): The dtype that the JAX QNode should
+ output. If not provided, the default is ``jnp.float64``.
+
+ Raises:
+ .QuantumFunctionError: if TensorFlow >= 2.1 is not installed
+ """
+ # pylint: disable=import-outside-toplevel
+ try:
+ from pennylane.tape.interfaces.jax import JAXInterface
+
+ if self.interface != "jax" and self.interface is not None:
+ # Since the interface is changing, need to re-validate the tape class.
+ self._tape, interface, self.device, diff_options = self.get_tape(
+ self._original_device, "jax", self.diff_method
+ )
+
+ self.interface = interface
+ self.diff_options.update(diff_options)
+ else:
+ self.interface = "jax"
+
+ if self.qtape is not None:
+ JAXInterface.apply(self.qtape)
+
+ except ImportError as e:
raise qml.QuantumFunctionError(
- "The JAX interface can only be used with "
- "diff_method='backprop' on supported devices"
- )
- self.interface = "jax"
+ "JAX not found. Please install the latest "
+ "version of JAX to enable the 'jax' interface."
+ ) from e
INTERFACE_MAP = {"autograd": to_autograd, "torch": to_torch, "tf": to_tf, "jax": to_jax}
diff --git a/tests/devices/test_default_qubit_jax.py b/tests/devices/test_default_qubit_jax.py
index bbc09b3f9c7..04b0ffe67b6 100644
--- a/tests/devices/test_default_qubit_jax.py
+++ b/tests/devices/test_default_qubit_jax.py
@@ -436,26 +436,6 @@ def cost(weights):
grad = jax.grad(cost)(weights)
assert grad.shape == weights.shape
- def test_non_backprop_error(self):
- """Test that an error is raised in tape mode if the diff method is not backprop"""
- if not qml.tape_mode_active():
- pytest.skip("Test only applies in tape mode")
-
- dev = qml.device("default.qubit.jax", wires=2)
-
- def circuit(weights):
- qml.RX(weights[0], wires=0)
- qml.RY(weights[1], wires=1)
- qml.CNOT(wires=[0, 1])
- return qml.expval(qml.PauliZ(0))
-
- qnode = qml.QNode(circuit, dev, interface="jax", diff_method="parameter-shift")
- weights = jnp.array([0.1, 0.2])
-
- with pytest.raises(qml.QuantumFunctionError, match="The JAX interface can only be used with"):
- qnode(weights)
-
-
class TestOps:
"""Unit tests for operations supported by the default.qubit.jax device"""
diff --git a/tests/tape/interfaces/test_qnode_jax.py b/tests/tape/interfaces/test_qnode_jax.py
new file mode 100644
index 00000000000..79069073bbd
--- /dev/null
+++ b/tests/tape/interfaces/test_qnode_jax.py
@@ -0,0 +1,203 @@
+# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for the JAX interface"""
+import pytest
+jax = pytest.importorskip("jax")
+jnp = pytest.importorskip("jax.numpy")
+import numpy as np
+import pennylane as qml
+from pennylane.tape import JacobianTape, qnode, QNode, QubitParamShiftTape
+
+def test_qnode_intergration():
+ """Test a simple use of qnode with a JAX interface and non-JAX device"""
+ dev = qml.device("default.mixed", wires=2) # A non-JAX device
+
+ @qml.qnode(dev, interface="jax")
+ def circuit(weights):
+ qml.RX(weights[0], wires=0)
+ qml.RZ(weights[1], wires=1)
+ return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
+
+ weights = jnp.array([0.1, 0.2])
+ val = circuit(weights)
+ assert "DeviceArray" in val.__repr__()
+
+def test_to_jax():
+ """Test the to_jax method"""
+ dev = qml.device("default.mixed", wires=2)
+
+ @qml.qnode(dev, interface="autograd")
+ def circuit(weights):
+ qml.RX(weights[0], wires=0)
+ qml.RZ(weights[1], wires=1)
+ return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
+
+ circuit.to_jax()
+ weights = jnp.array([0.1, 0.2])
+ val = circuit(weights)
+ assert "DeviceArray" in val.__repr__()
+
+
+def test_simple_jacobian():
+ """Test the use of jax.jaxrev"""
+ dev = qml.device("default.mixed", wires=2) # A non-JAX device.
+
+ @qml.qnode(dev, interface="jax", diff_method="parameter-shift")
+ def circuit(weights):
+ qml.RX(weights[0], wires=0)
+ qml.RY(weights[1], wires=1)
+ return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
+
+ weights = jnp.array([0.1, 0.2])
+ grads = jax.jacrev(circuit)(weights)
+ # This is the easiest way to ensure our object is a DeviceArray instead
+ # of a numpy array.
+ assert "DeviceArray" in grads.__repr__()
+ assert grads.shape == (2,)
+ np.testing.assert_allclose(grads, np.array([-0.09784342, -0.19767685]))
+
+def test_simple_grad():
+ """Test the use of jax.grad"""
+ dev = qml.device("default.mixed", wires=2) # A non-JAX device.
+ @qml.qnode(dev, interface="jax", diff_method="parameter-shift")
+ def circuit(weights):
+ qml.RX(weights[0], wires=0)
+ qml.RZ(weights[1], wires=1)
+ return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
+
+ weights = jnp.array([0.1, 0.2])
+ val = jax.grad(circuit)(weights)
+ assert "DeviceArray" in val.__repr__()
+
+@pytest.mark.parametrize("diff_method", ['parameter-shift', 'finite-diff'])
+def test_differentiable_expand(diff_method):
+ """Test that operation and nested tapes expansion
+ is differentiable"""
+ class U3(qml.U3):
+ def expand(self):
+ theta, phi, lam = self.data
+ wires = self.wires
+
+ with JacobianTape() as tape:
+ qml.Rot(lam, theta, -lam, wires=wires)
+ qml.PhaseShift(phi + lam, wires=wires)
+
+ return tape
+
+ dev = qml.device("default.mixed", wires=1)
+ a = jnp.array(0.1)
+ p = jnp.array([0.1, 0.2, 0.3])
+
+ @qnode(dev, diff_method=diff_method, interface="jax")
+ def circuit(a, p):
+ qml.RX(a, wires=0)
+ U3(p[0], p[1], p[2], wires=0)
+ return qml.expval(qml.PauliX(0))
+
+ res = circuit(a, p)
+
+ expected = np.cos(a) * np.cos(p[1]) * np.sin(p[0]) + np.sin(a) * (
+ np.cos(p[2]) * np.sin(p[1]) + np.cos(p[0]) * np.cos(p[1]) * np.sin(p[2])
+ )
+ tol = 1e-5
+ assert np.allclose(res, expected, atol=tol, rtol=0)
+
+ res = jax.grad(circuit, argnums=1)(a, p)
+ expected = np.array(
+ [
+ np.cos(p[1]) * (np.cos(a) * np.cos(p[0]) - np.sin(a) * np.sin(p[0]) * np.sin(p[2])),
+ np.cos(p[1]) * np.cos(p[2]) * np.sin(a)
+ - np.sin(p[1])
+ * (np.cos(a) * np.sin(p[0]) + np.cos(p[0]) * np.sin(a) * np.sin(p[2])),
+ np.sin(a)
+ * (np.cos(p[0]) * np.cos(p[1]) * np.cos(p[2]) - np.sin(p[1]) * np.sin(p[2])),
+ ]
+ )
+ assert np.allclose(res, expected, atol=tol, rtol=0)
+
+def qtransform(qnode, a, framework=jnp):
+ """Transforms every RY(y) gate in a circuit to RX(-a*cos(y))"""
+
+ def construct(self, args, kwargs):
+ """New quantum tape construct method, that performs
+ the transform on the tape in a define-by-run manner"""
+
+ t_op = []
+
+ QNode.construct(self, args, kwargs)
+
+ new_ops = []
+ for o in self.qtape.operations:
+ # here, we loop through all tape operations, and make
+ # the transformation if a RY gate is encountered.
+ if isinstance(o, qml.RY):
+ t_op.append(qml.RX(-a * framework.cos(o.data[0]), wires=o.wires))
+ new_ops.append(t_op[-1])
+ else:
+ new_ops.append(o)
+
+ self.qtape._ops = new_ops
+ self.qtape._update()
+
+ import copy
+
+ new_qnode = copy.deepcopy(qnode)
+ new_qnode.construct = construct.__get__(new_qnode, QNode)
+ return new_qnode
+
+
+@pytest.mark.parametrize(
+ "dev_name,diff_method",
+ [("default.mixed", "finite-diff"), ("default.qubit.autograd", "parameter-shift")],
+)
+def test_transform(dev_name, diff_method, monkeypatch, tol):
+ """Test an example transform"""
+ monkeypatch.setattr(qml.operation.Operation, "do_check_domain", False)
+
+ dev = qml.device(dev_name, wires=1)
+
+ @qnode(dev, interface="jax", diff_method=diff_method)
+ def circuit(weights):
+ op1 = qml.RY(weights[0], wires=0)
+ op2 = qml.RX(weights[1], wires=0)
+ return qml.expval(qml.PauliZ(wires=0))
+
+ weights = np.array([0.32, 0.543])
+ a = np.array(0.5)
+
+ def loss(weights, a):
+ # transform the circuit QNode with trainable weight 'a'
+ new_circuit = qtransform(circuit, a)
+
+ # evaluate the transformed QNode
+ res = new_circuit(weights)
+
+ # evaluate the original QNode with pre-processed parameters
+ res2 = circuit(jnp.sin(weights))
+
+ # return the sum of the two QNode evaluations
+ return res + res2
+
+ res = loss(weights, a)
+
+ grad = jax.grad(loss, argnums=[0, 1])(weights, a)
+ assert len(grad) == 2
+ assert grad[0].shape == weights.shape
+ assert grad[1].shape == a.shape
+
+ # compare against the expected values
+ tol = 1e-5
+ assert np.allclose(res, 1.8244501889992706, atol=tol, rtol=0)
+ assert np.allclose(grad[0], [-0.26610258, -0.47053553], atol=tol, rtol=0)
+ assert np.allclose(grad[1], 0.06486032, atol=tol, rtol=0)
diff --git a/tests/tape/interfaces/test_tape_jax.py b/tests/tape/interfaces/test_tape_jax.py
new file mode 100644
index 00000000000..234a0f3c5e8
--- /dev/null
+++ b/tests/tape/interfaces/test_tape_jax.py
@@ -0,0 +1,124 @@
+# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for the JAX interface"""
+import pytest
+jax = pytest.importorskip("jax")
+jnp = pytest.importorskip("jax.numpy")
+import numpy as np
+from functools import partial
+import pennylane as qml
+from pennylane.tape import JacobianTape
+from pennylane.tape.interfaces.jax import JAXInterface
+
+
+class TestJAXQuantumTape:
+ """Test the JAX interface applied to a tape"""
+
+ def test_interface_str(self):
+ """Test that the interface string is correctly identified as JAX"""
+ with JAXInterface.apply(JacobianTape()) as tape:
+ qml.RX(0.5, wires=0)
+ qml.expval(qml.PauliX(0))
+
+ assert tape.interface == "jax"
+ assert isinstance(tape, JAXInterface)
+
+ def test_get_parameters(self):
+ """Test that the get_parameters function correctly gets the trainable parameters and all
+ parameters, depending on the trainable_only argument"""
+ a = jnp.array(0.1)
+ b = jnp.array(0.2)
+ c = jnp.array(0.3)
+ d = jnp.array(0.4)
+
+ with JAXInterface.apply(JacobianTape()) as tape:
+ qml.Rot(a, b, c, wires=0)
+ qml.RX(d, wires=1)
+ qml.CNOT(wires=[0, 1])
+ qml.expval(qml.PauliX(0))
+
+ np.testing.assert_array_equal(tape.get_parameters(), [a, b, c, d])
+
+ def test_execution(self):
+ """Test execution"""
+ a = jnp.array(0.1)
+ b = jnp.array(0.2)
+
+ def cost(a, b, device):
+ with JAXInterface.apply(JacobianTape()) as tape:
+ qml.RY(a, wires=0)
+ qml.RX(b, wires=0)
+ qml.expval(qml.PauliZ(0))
+ return tape.execute(device)
+
+ dev = qml.device("default.qubit", wires=1)
+ res = cost(a, b, device=dev)
+ assert res.shape == (1,)
+ # Easiest way to test object is a device array instead of np.array
+ assert "DeviceArray" in res.__repr__()
+
+
+ def test_state_raises(self):
+ """Test returning state raises exception"""
+ a = jnp.array(0.1)
+ b = jnp.array(0.2)
+
+ def cost(a, b, device):
+ with JAXInterface.apply(JacobianTape()) as tape:
+ qml.RY(a, wires=0)
+ qml.RX(b, wires=0)
+ qml.state()
+ return tape.execute(device)
+
+ dev = qml.device("default.qubit", wires=1)
+ # TODO(chase): Make this actually work and not raise an error.
+ with pytest.raises(ValueError):
+ res = cost(a, b, device=dev)
+
+ def test_execution_with_jit(self):
+ """Test execution"""
+ a = jnp.array(0.1)
+ b = jnp.array(0.2)
+
+ def cost(a, b, device):
+ with JAXInterface.apply(JacobianTape()) as tape:
+ qml.RY(a, wires=0)
+ qml.RX(b, wires=0)
+ qml.expval(qml.PauliZ(0))
+ return tape.execute(device)
+
+ # Not a JAX device!
+ dev = qml.device("default.qubit", wires=1)
+ dev_cost = partial(cost, device=dev)
+ res = jax.jit(dev_cost)(a, b)
+ assert res.shape == (1,)
+ # Easiest way to test object is a device array instead of np.array
+ assert "DeviceArray" in res.__repr__()
+
+ def test_qnode_interface(self):
+
+ dev = qml.device("default.mixed", wires=1)
+
+ @qml.qnode(dev, interface="jax")
+ def circuit(a, b):
+ qml.RY(a, wires=0)
+ qml.RX(b, wires=0)
+ return qml.expval(qml.PauliZ(0))
+
+ a = jnp.array(0.1)
+ b = jnp.array(0.2)
+
+ res = circuit(a, b)
+ assert "DeviceArray" in res.__repr__()
+