Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Circuit cutting: add mid-circuit measurement integration tests #2234

Merged
merged 24 commits into from
Feb 28, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
73acf20
Add test
trbromley Feb 25, 2022
11c9091
Add to changelog
trbromley Feb 25, 2022
447af24
Add tests in different interfaces
trbromley Feb 25, 2022
a88da66
Work on TF Jit
trbromley Feb 25, 2022
b4314e4
Remove JIT tests
trbromley Feb 25, 2022
f3e8a56
Fix bug in changing wires of MeasurementProcesses
trbromley Feb 25, 2022
5b69cb5
Fix graph_to_tape
trbromley Feb 25, 2022
f0f3839
Add supporting tests
trbromley Feb 25, 2022
7192f3d
Parametrize over use_opt_einsum
trbromley Feb 25, 2022
0cc78d7
Merge branch 'qcut_integ_tests' into qcut_integ_tests_2
trbromley Feb 25, 2022
4ac5f7b
Copy observable to prevent in-place modification
trbromley Feb 25, 2022
6f68b95
Add test
trbromley Feb 25, 2022
89dd999
Remove added line
trbromley Feb 25, 2022
01e507c
Add to changelog
trbromley Feb 25, 2022
b4a9d66
Merge branch 'master' into qcut_integ_tests
anthayes92 Feb 26, 2022
23e7c63
Remove mention of jit
trbromley Feb 27, 2022
c7b8532
Merge branch 'qcut_integ_tests' into qcut_integ_tests_2
trbromley Feb 27, 2022
9bf43a6
Remove isinstance check
trbromley Feb 27, 2022
57c247d
Merge branch 'qcut_integ_tests' into qcut_integ_tests_2
trbromley Feb 27, 2022
3d39cb9
Merge branch 'qcut_integ_tests' of github.com:XanaduAI/pennylane into…
trbromley Feb 28, 2022
fe71c95
Merge branch 'master' into qcut_integ_tests
trbromley Feb 28, 2022
9554fe8
Merge branch 'qcut_integ_tests' into qcut_integ_tests_2
trbromley Feb 28, 2022
55da371
Switch from deepcopy to copy
trbromley Feb 28, 2022
bb99ee3
Merge branch 'master' into qcut_integ_tests_2
trbromley Feb 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@
of graph partitioning parameters.
[(#2168)](https://github.com/PennyLaneAI/pennylane/pull/2168)

A suite of integration tests has been added.
[(#2231)](https://github.com/PennyLaneAI/pennylane/pull/2231)

<h3>Improvements</h3>

* The `gradients` module has been streamlined and special-purpose functions
Expand Down
25 changes: 21 additions & 4 deletions pennylane/transforms/qcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,12 +361,16 @@ def graph_to_tape(graph: MultiDiGraph) -> QuantumTape:
wire_map = {w: w for w in wires}
reverse_wire_map = {v: k for k, v in wire_map.items()}

copy_ops = [copy.copy(op) for _, op in ordered_ops]
copy_ops = [copy.copy(op) for _, op in ordered_ops if not isinstance(op, MeasurementProcess)]
copy_meas = [copy.copy(op) for _, op in ordered_ops if isinstance(op, MeasurementProcess)]
Comment on lines +364 to +365
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously we were treating the ops (gates) and measurements together. However, when isinstance(op, MeasurementProcess), then setting op._wires = ... does not work when the MeasurementProcess is an expectation value of an observable, because the measurement process uses the contained observable to determine op.wires.

In this updated version, we treat the MeasurementProcesses more carefully after applying the gates.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

I'm curious what where the cases that revealed this? Mid circuit measurements and tensor products of obs in measurements?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, previously if your observable (or part of it) was on wire i, but wire i had mid-circuit measurements and was relabelled to wire i', then the observable wasn't getting its wire updated from wire i to i'.

observables = []

with QuantumTape() as tape:
for op in copy_ops:
new_wires = [wire_map[w] for w in op.wires]
op._wires = Wires(new_wires) # TODO: find a better way to update operation wires
new_wires = Wires([wire_map[w] for w in op.wires])

# TODO: find a better way to update operation wires
op._wires = new_wires
apply(op)

if isinstance(op, MeasureNode):
Expand All @@ -380,6 +384,19 @@ def graph_to_tape(graph: MultiDiGraph) -> QuantumTape:
wire_map[original_wire] = new_wire
reverse_wire_map[new_wire] = original_wire

for meas in copy_meas:
obs = meas.obs
obs._wires = Wires([wire_map[w] for w in obs.wires])
observables.append(obs)

# We assume that each MeasurementProcess node in the graph contributes to a single
# expectation value of an observable, given by the tensor product over the observables of
# each MeasurementProcess.
if len(observables) > 1:
qml.expval(Tensor(*observables))
elif len(observables) == 1:
qml.expval(obs)
trbromley marked this conversation as resolved.
Show resolved Hide resolved

return tape


Expand Down Expand Up @@ -424,7 +441,7 @@ def _get_measurements(

obs = measurement.obs

return [expval(obs @ g) for g in group]
return [expval(copy.deepcopy(obs) @ g) for g in group]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am very confused by this. the tensor product works in-place 🤔 Perhaps it is a queuing-related quirk.

obs = qml.PauliZ(0) @ qml.PauliZ(1)
obs2 = obs @ qml.PauliZ(2)
assert obs == obs2
print(obs)

Copy link
Contributor

@anthayes92 anthayes92 Feb 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, so the original obs is being overwritten in when we take a tensor product of it with another tensor. This is weird!

But deepcopy seems to solve for our needs, or is there potential for more problems with this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also works with copy.copy so I changed to that. I think copy.copy is a sufficient workaround for now, but long run it'd be better for observables not to be altered when taking the tensor product. Issue posted: #2235



def _prep_zero_state(wire):
Expand Down
172 changes: 170 additions & 2 deletions tests/transforms/test_qcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,41 @@ def test_multiple_conversions(self):
for tape1, tape2 in zip(tapes1, tapes2):
compare_tapes(tape1, tape2)

def test_identity(self):
"""Tests that the graph_to_tape function correctly performs the inverse of the tape_to_graph
function, including converting a tensor product expectation value into separate nodes in the
graph returned by tape_to_graph, and then combining those nodes again into a single tensor
product in the circuit returned by graph_to_tape"""

with qml.tape.QuantumTape() as tape:
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

graph = qcut.tape_to_graph(tape)
tape_out = qcut.graph_to_tape(graph)

compare_tapes(tape, tape_out)
assert len(tape_out.measurements) == 1

def test_change_obs_wires(self):
"""Tests that the graph_to_tape function correctly swaps the wires of observables when
the tape contains mid-circuit measurements"""

with qml.tape.QuantumTape() as tape:
qml.CNOT(wires=[0, 1])
qcut.MeasureNode(wires=1)
qcut.PrepareNode(wires=1)
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(1))

graph = qcut.tape_to_graph(tape)
tape_out = qcut.graph_to_tape(graph)

m = tape_out.measurements
assert len(m) == 1
assert m[0].wires == Wires([2])
assert m[0].obs.name == "PauliZ"


class TestGetMeasurements:
"""Tests for the _get_measurements function"""
Expand Down Expand Up @@ -1802,12 +1837,13 @@ def f(x):
assert np.allclose(grad, expected_grad)


@pytest.mark.parametrize("use_opt_einsum", [True, False])
class TestCutCircuitTransform:
"""
Tests for the cut_circuit transform
"""

def test_simple_cut_circuit(self, mocker):
def test_simple_cut_circuit(self, mocker, use_opt_einsum):
"""
Tests the full circuit cutting pipeline returns the correct value and
gradient for a simple circuit using the `cut_circuit` transform.
Expand All @@ -1827,7 +1863,139 @@ def circuit(x):

spy = mocker.spy(qcut, "qcut_processing_fn")
x = np.array(0.531, requires_grad=True)
cut_circuit = qcut.cut_circuit(circuit)
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum)

assert np.isclose(cut_circuit(x), float(circuit(x)))
spy.assert_called_once()

gradient = qml.grad(circuit)(x)
cut_gradient = qml.grad(cut_circuit)(x)

assert np.isclose(gradient, cut_gradient)

def test_simple_cut_circuit_torch(self, use_opt_einsum):
"""
Tests the full circuit cutting pipeline returns the correct value and
gradient for a simple circuit using the `cut_circuit` transform with the torch interface.
"""
torch = pytest.importorskip("torch")

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="torch")
def circuit(x):
qml.RX(x, wires=0)
qml.RY(0.543, wires=1)
qml.WireCut(wires=0)
qml.CNOT(wires=[0, 1])
qml.RZ(0.240, wires=0)
qml.RZ(0.133, wires=1)
return qml.expval(qml.PauliZ(wires=[0]))

x = torch.tensor(0.531, requires_grad=True)
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum)

res = cut_circuit(x)
res_expected = circuit(x)
assert np.isclose(res.detach().numpy(), res_expected.detach().numpy())
assert isinstance(res, torch.Tensor)

res.backward()
grad = x.grad.detach().numpy()

x.grad = None
res_expected.backward()
grad_expected = x.grad.detach().numpy()

assert np.isclose(grad, grad_expected)

def test_simple_cut_circuit_tf(self, use_opt_einsum):
"""
Tests the full circuit cutting pipeline returns the correct value and
gradient for a simple circuit using the `cut_circuit` transform with the TF interface.
"""
tf = pytest.importorskip("tensorflow")

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="tf")
def circuit(x):
qml.RX(x, wires=0)
qml.RY(0.543, wires=1)
qml.WireCut(wires=0)
qml.CNOT(wires=[0, 1])
qml.RZ(0.240, wires=0)
qml.RZ(0.133, wires=1)
return qml.expval(qml.PauliZ(wires=[0]))

x = tf.Variable(0.531)
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum)

with tf.GradientTape() as tape:
res = cut_circuit(x)

grad = tape.gradient(res, x)

with tf.GradientTape() as tape:
res_expected = circuit(x)

grad_expected = tape.gradient(res_expected, x)

assert np.isclose(res, res_expected)
assert np.isclose(grad, grad_expected)

def test_simple_cut_circuit_jax(self, use_opt_einsum):
"""
Tests the full circuit cutting pipeline returns the correct value and
gradient for a simple circuit using the `cut_circuit` transform with the Jax interface and
using JIT.
"""
jax = pytest.importorskip("jax")
import jax.numpy as jnp

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="jax")
def circuit(x):
qml.RX(x, wires=0)
qml.RY(0.543, wires=1)
qml.WireCut(wires=0)
qml.CNOT(wires=[0, 1])
qml.RZ(0.240, wires=0)
qml.RZ(0.133, wires=1)
return qml.expval(qml.PauliZ(wires=[0]))

x = jnp.array(0.531)
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum)

res = cut_circuit(x)
res_expected = circuit(x)

grad = jax.grad(cut_circuit)(x)
grad_expected = jax.grad(circuit)(x)

assert np.isclose(res, res_expected)
assert np.isclose(grad, grad_expected)

def test_with_mid_circuit_measurement(self, mocker, use_opt_einsum):
"""Tests the full circuit cutting pipeline returns the correct value and gradient for a
circuit that contains mid-circuit measurements, using the `cut_circuit` transform."""
dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def circuit(x):
qml.RX(x, wires=0)
qml.CNOT(wires=[0, 1])
qml.WireCut(wires=1)
qml.RX(np.sin(x) ** 2, wires=1)
qml.CNOT(wires=[1, 2])
qml.WireCut(wires=1)
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

spy = mocker.spy(qcut, "qcut_processing_fn")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps it would be more appropriate to spy on _get_measurements in this test

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to change, but why do you think _get_measurements is more appropriate?

x = np.array(0.531, requires_grad=True)
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum)

assert np.isclose(cut_circuit(x), float(circuit(x)))
spy.assert_called_once()
Expand Down