Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax interface for all devices. #1076

Merged
merged 14 commits into from
Feb 9, 2021
18 changes: 17 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@

<h3>New features since last release</h3>

- The JAX interface now supports all devices.
chaserileyroberts marked this conversation as resolved.
Show resolved Hide resolved
[(#1076)](https://github.com/PennyLaneAI/pennylane/pull/1076)

Here is an example of how to use JAX with Cirq:

```python
dev = qml.device('cirq.simulator', wires=1)
@qml.qnode(dev, interface="jax")
def circuit(x):
qml.RX(x[1], wires=0)
qml.Rot(x[0], x[1], x[2], wires=0)
return qml.expval(qml.PauliZ(0))
weights = jnp.array([0.2, 0.5, 0.1])
print(circuit(weights)) # DeviceArray(...)
```

- Added the `ControlledPhaseShift` gate as well as the `QFT` operation for applying quantum Fourier
transforms.
[(#1064)](https://github.com/PennyLaneAI/pennylane/pull/1064)
Expand Down Expand Up @@ -59,7 +75,7 @@

This release contains contributions from (in alphabetical order):

Thomas Bromley, Josh Izaac, Daniel Polatajko
Thomas Bromley, Josh Izaac, Daniel Polatajko, Chase Roberts

# Release 0.14.0 (current release)

Expand Down
139 changes: 139 additions & 0 deletions pennylane/tape/interfaces/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the mixin interface class for creating differentiable quantum tapes with
JAX.
"""
from functools import partial
import jax
import jax.experimental.host_callback as host_callback
import jax.numpy as jnp
from pennylane.tape.queuing import AnnotatedQueue
from pennylane.operation import Variance, Expectation


class JAXInterface(AnnotatedQueue):
"""Mixin class for applying an JAX interface to a :class:`~.JacobianTape`.
JAX-compatible quantum tape classes can be created via subclassing:
.. code-block:: python
class MyJAXQuantumTape(JAXInterface, JacobianTape):
Alternatively, the JAX interface can be dynamically applied to existing
quantum tapes via the :meth:`~.apply` class method. This modifies the
tape **in place**.
Once created, the JAX interface can be used to perform quantum-classical
differentiable programming.
.. note::
If using a device that supports native JAX computation and backpropagation, such as
:class:`~.DefaultQubitJAX`, the JAX interface **does not need to be applied**. It
is only applied to tapes executed on non-JAX compatible devices.
**Example**
Once a JAX quantum tape has been created, it can be differentiated using JAX:
.. code-block:: python
tape = JAXInterface.apply(JacobianTape())
with tape:
qml.Rot(0, 0, 0, wires=0)
expval(qml.PauliX(0))
def cost_fn(x, y, z, device):
tape.set_parameters([x, y ** 2, y * np.sin(z)], trainable_only=False)
return tape.execute(device=device)
>>> x = jnp.array(0.1, requires_grad=False)
>>> y = jnp.array(0.2, requires_grad=True)
>>> z = jnp.array(0.3, requires_grad=True)
>>> dev = qml.device("default.qubit", wires=2)
>>> cost_fn(x, y, z, device=dev)
DeviceArray([ 0.03991951], dtype=float32)
>>> jac_fn = jax.vjp(cost_fn)
>>> jac_fn(x, y, z, device=dev)
DeviceArray([[ 0.39828408, -0.00045133]], dtype=float32)
"""

# pylint: disable=attribute-defined-outside-init
dtype = jnp.float64

@property
def interface(self): # pylint: disable=missing-function-docstring
return "jax"

def _execute(self, params, device):
# TODO (chase): Add support for more than 1 measured observable.
mariaschuld marked this conversation as resolved.
Show resolved Hide resolved
if len(self.observables) != 1:
raise ValueError(
"The JAX interface currently only supports quantum nodes with a single return type."
)
return_type = self.observables[0].return_type
if return_type is not Variance and return_type is not Expectation:
raise ValueError(
f"Only Variance and Expectation returns are support for the JAX interface, given {return_type}."
)

@jax.custom_vjp
def wrapped_exec(params):
exec_fn = partial(self.execute_device, device=device)
return host_callback.call(
exec_fn, params, result_shape=jax.ShapeDtypeStruct((1,), JAXInterface.dtype)
)

def wrapped_exec_fwd(params):
return wrapped_exec(params), params

def wrapped_exec_bwd(params, g):
def jacobian(params):
tape = self.copy()
tape.set_parameters(params)
return tape.jacobian(device, params=params, **tape.jacobian_options)

val = g.reshape((-1,)) * host_callback.call(
jacobian,
params,
result_shape=jax.ShapeDtypeStruct((1, len(params)), JAXInterface.dtype),
)
return (list(val.reshape((-1,))),) # Comma is on purpose.

wrapped_exec.defvjp(wrapped_exec_fwd, wrapped_exec_bwd)
return wrapped_exec(params)

@classmethod
def apply(cls, tape):
"""Apply the JAX interface to an existing tape in-place.
Args:
tape (.JacobianTape): a quantum tape to apply the JAX interface to
**Example**
>>> with JacobianTape() as tape:
... qml.RX(0.5, wires=0)
... expval(qml.PauliZ(0))
>>> JAXInterface.apply(tape)
>>> tape
<JAXQuantumTape: wires=<Wires = [0]>, params=1>
"""
tape_class = getattr(tape, "__bare__", tape.__class__)
tape.__bare__ = tape_class
tape.__class__ = type("JAXQuantumTape", (cls, tape_class), {})
return tape
37 changes: 31 additions & 6 deletions pennylane/tape/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,13 +766,38 @@ def to_autograd(self):
AutogradInterface.apply(self.qtape)

def to_jax(self):
"""Validation checks when a user expects to use the JAX interface."""
if self.diff_method != "backprop":
"""Apply the JAX interface to the internal quantum tape.
Args:
dtype (tf.dtype): The dtype that the JAX QNode should
output. If not provided, the default is ``jnp.float64``.
Raises:
.QuantumFunctionError: if TensorFlow >= 2.1 is not installed
"""
# pylint: disable=import-outside-toplevel
try:
from pennylane.tape.interfaces.jax import JAXInterface

if self.interface != "jax" and self.interface is not None:
# Since the interface is changing, need to re-validate the tape class.
self._tape, interface, self.device, diff_options = self.get_tape(
self._original_device, "jax", self.diff_method
)

self.interface = interface
self.diff_options.update(diff_options)
else:
self.interface = "jax"

if self.qtape is not None:
JAXInterface.apply(self.qtape)

except ImportError as e:
raise qml.QuantumFunctionError(
"The JAX interface can only be used with "
"diff_method='backprop' on supported devices"
)
self.interface = "jax"
"JAX not found. Please install the latest "
"version of JAX to enable the 'jax' interface."
) from e

INTERFACE_MAP = {"autograd": to_autograd, "torch": to_torch, "tf": to_tf, "jax": to_jax}

Expand Down
20 changes: 0 additions & 20 deletions tests/devices/test_default_qubit_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,26 +436,6 @@ def cost(weights):
grad = jax.grad(cost)(weights)
assert grad.shape == weights.shape

def test_non_backprop_error(self):
"""Test that an error is raised in tape mode if the diff method is not backprop"""
if not qml.tape_mode_active():
pytest.skip("Test only applies in tape mode")

dev = qml.device("default.qubit.jax", wires=2)

def circuit(weights):
qml.RX(weights[0], wires=0)
qml.RY(weights[1], wires=1)
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0))

qnode = qml.QNode(circuit, dev, interface="jax", diff_method="parameter-shift")
weights = jnp.array([0.1, 0.2])

with pytest.raises(qml.QuantumFunctionError, match="The JAX interface can only be used with"):
qnode(weights)


chaserileyroberts marked this conversation as resolved.
Show resolved Hide resolved
class TestOps:
"""Unit tests for operations supported by the default.qubit.jax device"""

Expand Down
Loading