Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax interface for all devices. #1076

Merged
merged 14 commits into from
Feb 9, 2021
18 changes: 17 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@

<h3>New features since last release</h3>

- The JAX interface now supports all devices.
chaserileyroberts marked this conversation as resolved.
Show resolved Hide resolved

Here is an example of how to use JAX with Cirq:

```python
dev = qml.device('cirq.simulator', wires=1)
@qml.qnode(dev, interface="jax")
def circuit(x):
qml.RX(x[1], wires=0)
qml.Rot(x[0], x[1], x[2], wires=0)
return qml.expval(qml.PauliZ(0))
weights = jnp.array([0.2, 0.5, 0.1])
print(circuit(weights)) # DeviceArray(...)
```
[(#1076)](https://github.com/PennyLaneAI/pennylane/pull/1076)
chaserileyroberts marked this conversation as resolved.
Show resolved Hide resolved

- Added the `ControlledPhaseShift` gate as well as the `QFT` operation for applying quantum Fourier
transforms.
[(#1064)](https://github.com/PennyLaneAI/pennylane/pull/1064)
Expand All @@ -18,7 +34,7 @@

This release contains contributions from (in alphabetical order):

Thomas Bromley
Chase Roberts, Thomas Bromley

# Release 0.14.0 (current release)

Expand Down
116 changes: 116 additions & 0 deletions pennylane/tape/interfaces/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the mixin interface class for creating differentiable quantum tapes with
JAX.
"""
from functools import partial
import jax
import jax.experimental.host_callback as host_callback
import jax.numpy as jnp
from pennylane.tape.queuing import AnnotatedQueue
from pennylane.operation import Variance, Expectation


class JAXInterface(AnnotatedQueue):
"""Mixin class for applying an JAX interface to a :class:`~.JacobianTape`.

JAX-compatible quantum tape classes can be created via subclassing:

.. code-block:: python

class MyJAXQuantumTape(JAXInterface, JacobianTape):

Alternatively, the JAX interface can be dynamically applied to existing
quantum tapes via the :meth:`~.apply` class method. This modifies the
tape **in place**.

Once created, the JAX interface can be used to perform quantum-classical
differentiable programming.

.. note::

If using a device that supports native JAX computation and backpropagation, such as
:class:`~.DefaultQubitJAX`, the JAX interface **does not need to be applied**. It
is only applied to tapes executed on non-JAX compatible devices.

**Example**

Once a JAX quantum tape has been created, it can be differentiated using JAX:

.. code-block:: python

tape = JAXInterface.apply(JacobianTape())

with tape:
qml.Rot(0, 0, 0, wires=0)
expval(qml.PauliX(0))

def cost_fn(x, y, z, device):
tape.set_parameters([x, y ** 2, y * np.sin(z)], trainable_only=False)
return tape.execute(device=device)

>>> x = jnp.array(0.1, requires_grad=False)
>>> y = jnp.array(0.2, requires_grad=True)
>>> z = jnp.array(0.3, requires_grad=True)
>>> dev = qml.device("default.qubit", wires=2)
>>> cost_fn(x, y, z, device=dev)
DeviceArray([ 0.03991951], dtype=float32)
>>> jac_fn = jax.vjp(cost_fn)
>>> jac_fn(x, y, z, device=dev)
DeviceArray([[ 0.39828408, -0.00045133]], dtype=float32)
"""

# pylint: disable=attribute-defined-outside-init
dtype = jnp.float64

@property
def interface(self): # pylint: disable=missing-function-docstring
return "jax"

def _execute(self, params, device):
# TODO (chase): Add support for more than 1 measured observable.
mariaschuld marked this conversation as resolved.
Show resolved Hide resolved
if len(self.observables) != 1:
raise ValueError("Only one return type is supported currently")
chaserileyroberts marked this conversation as resolved.
Show resolved Hide resolved
return_type = self.observables[0].return_type
if return_type is not Variance and return_type is not Expectation:
raise ValueError(
f"Only Variance and Expectation returns are support, given {return_type}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, maybe mention that this is the JAX interface speaking here, and not general PennyLane?

)
exec_fn = partial(self.execute_device, device=device)

return host_callback.call(
exec_fn, params, result_shape=jax.ShapeDtypeStruct((1,), JAXInterface.dtype)
)

@classmethod
def apply(cls, tape):
"""Apply the JAX interface to an existing tape in-place.

Args:
tape (.JacobianTape): a quantum tape to apply the JAX interface to

**Example**

>>> with JacobianTape() as tape:
... qml.RX(0.5, wires=0)
... expval(qml.PauliZ(0))
>>> JAXInterface.apply(tape)
>>> tape
<JAXQuantumTape: wires=<Wires = [0]>, params=1>
"""
tape_class = getattr(tape, "__bare__", tape.__class__)
tape.__bare__ = tape_class
tape.__class__ = type("JAXQuantumTape", (cls, tape_class), {})
return tape
37 changes: 31 additions & 6 deletions pennylane/tape/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,13 +754,38 @@ def to_autograd(self):
AutogradInterface.apply(self.qtape)

def to_jax(self):
"""Validation checks when a user expects to use the JAX interface."""
if self.diff_method != "backprop":
"""Apply the JAX interface to the internal quantum tape.

Args:
dtype (tf.dtype): The dtype that the JAX QNode should
output. If not provided, the default is ``jnp.float64``.

Raises:
.QuantumFunctionError: if TensorFlow >= 2.1 is not installed
"""
# pylint: disable=import-outside-toplevel
try:
from pennylane.tape.interfaces.jax import JAXInterface

if self.interface != "jax" and self.interface is not None:
# Since the interface is changing, need to re-validate the tape class.
self._tape, interface, self.device, diff_options = self.get_tape(
self._original_device, "jax", self.diff_method
)

self.interface = interface
self.diff_options.update(diff_options)
else:
self.interface = "jax"

if self.qtape is not None:
JAXInterface.apply(self.qtape)

except ImportError as e:
raise qml.QuantumFunctionError(
"The JAX interface can only be used with "
"diff_method='backprop' on supported devices"
)
self.interface = "jax"
"JAX not found. Please install the latest "
"version of JAX to enable the 'jax' interface."
) from e

INTERFACE_MAP = {"autograd": to_autograd, "torch": to_torch, "tf": to_tf, "jax": to_jax}

Expand Down
20 changes: 0 additions & 20 deletions tests/devices/test_default_qubit_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,26 +436,6 @@ def cost(weights):
grad = jax.grad(cost)(weights)
assert grad.shape == weights.shape

def test_non_backprop_error(self):
"""Test that an error is raised in tape mode if the diff method is not backprop"""
if not qml.tape_mode_active():
pytest.skip("Test only applies in tape mode")

dev = qml.device("default.qubit.jax", wires=2)

def circuit(weights):
qml.RX(weights[0], wires=0)
qml.RY(weights[1], wires=1)
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0))

qnode = qml.QNode(circuit, dev, interface="jax", diff_method="parameter-shift")
weights = jnp.array([0.1, 0.2])

with pytest.raises(qml.QuantumFunctionError, match="The JAX interface can only be used with"):
qnode(weights)


chaserileyroberts marked this conversation as resolved.
Show resolved Hide resolved
class TestOps:
"""Unit tests for operations supported by the default.qubit.jax device"""

Expand Down
33 changes: 33 additions & 0 deletions tests/tape/interfaces/test_qnode_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the JAX interface"""
import pytest
jax = pytest.importorskip("jax")
jnp = pytest.importorskip("jax.numpy")
import pennylane as qml
from pennylane.tape import JacobianTape, qnode, QNode, QubitParamShiftTape


def test_qnode_intergration():
mariaschuld marked this conversation as resolved.
Show resolved Hide resolved
dev = qml.device("default.mixed", wires=2) # A non-JAX device

@qml.qnode(dev, interface="jax")
def circuit(weights):
qml.RX(weights[0], wires=0)
qml.RZ(weights[0], wires=1)
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

weights = jnp.array([0.1, 0.2])
val = circuit(weights)
assert "DeviceArray" in val.__repr__()
124 changes: 124 additions & 0 deletions tests/tape/interfaces/test_tape_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the JAX interface"""
import pytest
jax = pytest.importorskip("jax")
jnp = pytest.importorskip("jax.numpy")
import numpy as np
from functools import partial
import pennylane as qml
from pennylane.tape import JacobianTape
from pennylane.tape.interfaces.jax import JAXInterface


class TestJAXQuantumTape:
"""Test the JAX interface applied to a tape"""

def test_interface_str(self):
"""Test that the interface string is correctly identified as JAX"""
with JAXInterface.apply(JacobianTape()) as tape:
qml.RX(0.5, wires=0)
qml.expval(qml.PauliX(0))

assert tape.interface == "jax"
assert isinstance(tape, JAXInterface)

def test_get_parameters(self):
"""Test that the get_parameters function correctly gets the trainable parameters and all
parameters, depending on the trainable_only argument"""
a = jnp.array(0.1)
b = jnp.array(0.2)
c = jnp.array(0.3)
d = jnp.array(0.4)

with JAXInterface.apply(JacobianTape()) as tape:
qml.Rot(a, b, c, wires=0)
qml.RX(d, wires=1)
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliX(0))

np.testing.assert_array_equal(tape.get_parameters(), [a, b, c, d])

def test_execution(self):
"""Test execution"""
a = jnp.array(0.1)
b = jnp.array(0.2)

def cost(a, b, device):
with JAXInterface.apply(JacobianTape()) as tape:
qml.RY(a, wires=0)
qml.RX(b, wires=0)
qml.expval(qml.PauliZ(0))
return tape.execute(device)

dev = qml.device("default.qubit", wires=1)
res = cost(a, b, device=dev)
assert res.shape == (1,)
# Easiest way to test object is a device array instead of np.array
assert "DeviceArray" in res.__repr__()


def test_state_raises(self):
"""Test returning state raises exception"""
a = jnp.array(0.1)
b = jnp.array(0.2)

def cost(a, b, device):
with JAXInterface.apply(JacobianTape()) as tape:
qml.RY(a, wires=0)
qml.RX(b, wires=0)
qml.state()
return tape.execute(device)

dev = qml.device("default.qubit", wires=1)
# TODO(chase): Make this actually work and not raise an error.
with pytest.raises(ValueError):
res = cost(a, b, device=dev)

def test_execution_with_jit(self):
"""Test execution"""
a = jnp.array(0.1)
b = jnp.array(0.2)

def cost(a, b, device):
with JAXInterface.apply(JacobianTape()) as tape:
qml.RY(a, wires=0)
qml.RX(b, wires=0)
qml.expval(qml.PauliZ(0))
return tape.execute(device)

# Not a JAX device!
dev = qml.device("default.qubit", wires=1)
dev_cost = partial(cost, device=dev)
res = jax.jit(dev_cost)(a, b)
assert res.shape == (1,)
# Easiest way to test object is a device array instead of np.array
assert "DeviceArray" in res.__repr__()

def test_qnode_interface(self):

dev = qml.device("default.mixed", wires=1)

@qml.qnode(dev, interface="jax")
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=0)
return qml.expval(qml.PauliZ(0))

a = jnp.array(0.1)
b = jnp.array(0.2)

res = circuit(a, b)
assert "DeviceArray" in res.__repr__()