Skip to content

Commit

Permalink
Circuit cutting: add mid-circuit measurement integration tests (#2234)
Browse files Browse the repository at this point in the history
* Add test

* Add to changelog

* Add tests in different interfaces

* Work on TF Jit

* Remove JIT tests

* Fix bug in changing wires of MeasurementProcesses

* Fix graph_to_tape

* Add supporting tests

* Parametrize over use_opt_einsum

* Copy observable to prevent in-place modification

* Add test

* Remove added line

* Add to changelog

* Remove mention of jit

* Remove isinstance check

* Switch from deepcopy to copy

Co-authored-by: anthayes92 <34694788+anthayes92@users.noreply.github.com>
  • Loading branch information
trbromley and anthayes92 authored Feb 28, 2022
1 parent 7ad040b commit 57237b2
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 4 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@

A suite of integration tests has been added.
[(#2231)](https://github.com/PennyLaneAI/pennylane/pull/2231)
[(#2234)](https://github.com/PennyLaneAI/pennylane/pull/2234)

<h3>Improvements</h3>

Expand Down
25 changes: 21 additions & 4 deletions pennylane/transforms/qcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,12 +361,16 @@ def graph_to_tape(graph: MultiDiGraph) -> QuantumTape:
wire_map = {w: w for w in wires}
reverse_wire_map = {v: k for k, v in wire_map.items()}

copy_ops = [copy.copy(op) for _, op in ordered_ops]
copy_ops = [copy.copy(op) for _, op in ordered_ops if not isinstance(op, MeasurementProcess)]
copy_meas = [copy.copy(op) for _, op in ordered_ops if isinstance(op, MeasurementProcess)]
observables = []

with QuantumTape() as tape:
for op in copy_ops:
new_wires = [wire_map[w] for w in op.wires]
op._wires = Wires(new_wires) # TODO: find a better way to update operation wires
new_wires = Wires([wire_map[w] for w in op.wires])

# TODO: find a better way to update operation wires
op._wires = new_wires
apply(op)

if isinstance(op, MeasureNode):
Expand All @@ -380,6 +384,19 @@ def graph_to_tape(graph: MultiDiGraph) -> QuantumTape:
wire_map[original_wire] = new_wire
reverse_wire_map[new_wire] = original_wire

for meas in copy_meas:
obs = meas.obs
obs._wires = Wires([wire_map[w] for w in obs.wires])
observables.append(obs)

# We assume that each MeasurementProcess node in the graph contributes to a single
# expectation value of an observable, given by the tensor product over the observables of
# each MeasurementProcess.
if len(observables) > 1:
qml.expval(Tensor(*observables))
elif len(observables) == 1:
qml.expval(obs)

return tape


Expand Down Expand Up @@ -424,7 +441,7 @@ def _get_measurements(

obs = measurement.obs

return [expval(obs @ g) for g in group]
return [expval(copy.copy(obs) @ g) for g in group]


def _prep_zero_state(wire):
Expand Down
165 changes: 165 additions & 0 deletions tests/transforms/test_qcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,41 @@ def test_multiple_conversions(self):
for tape1, tape2 in zip(tapes1, tapes2):
compare_tapes(tape1, tape2)

def test_identity(self):
"""Tests that the graph_to_tape function correctly performs the inverse of the tape_to_graph
function, including converting a tensor product expectation value into separate nodes in the
graph returned by tape_to_graph, and then combining those nodes again into a single tensor
product in the circuit returned by graph_to_tape"""

with qml.tape.QuantumTape() as tape:
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

graph = qcut.tape_to_graph(tape)
tape_out = qcut.graph_to_tape(graph)

compare_tapes(tape, tape_out)
assert len(tape_out.measurements) == 1

def test_change_obs_wires(self):
"""Tests that the graph_to_tape function correctly swaps the wires of observables when
the tape contains mid-circuit measurements"""

with qml.tape.QuantumTape() as tape:
qml.CNOT(wires=[0, 1])
qcut.MeasureNode(wires=1)
qcut.PrepareNode(wires=1)
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(1))

graph = qcut.tape_to_graph(tape)
tape_out = qcut.graph_to_tape(graph)

m = tape_out.measurements
assert len(m) == 1
assert m[0].wires == Wires([2])
assert m[0].obs.name == "PauliZ"


class TestGetMeasurements:
"""Tests for the _get_measurements function"""
Expand Down Expand Up @@ -1940,6 +1975,136 @@ def circuit(x):
assert np.isclose(res, res_expected)
assert np.isclose(grad, grad_expected)

def test_with_mid_circuit_measurement(self, mocker, use_opt_einsum):
"""Tests the full circuit cutting pipeline returns the correct value and gradient for a
circuit that contains mid-circuit measurements, using the `cut_circuit` transform."""
dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def circuit(x):
qml.RX(x, wires=0)
qml.CNOT(wires=[0, 1])
qml.WireCut(wires=1)
qml.RX(np.sin(x) ** 2, wires=1)
qml.CNOT(wires=[1, 2])
qml.WireCut(wires=1)
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

spy = mocker.spy(qcut, "qcut_processing_fn")
x = np.array(0.531, requires_grad=True)
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum)

assert np.isclose(cut_circuit(x), float(circuit(x)))
spy.assert_called_once()

gradient = qml.grad(circuit)(x)
cut_gradient = qml.grad(cut_circuit)(x)

assert np.isclose(gradient, cut_gradient)

def test_simple_cut_circuit_torch(self, use_opt_einsum):
"""
Tests the full circuit cutting pipeline returns the correct value and
gradient for a simple circuit using the `cut_circuit` transform with the torch interface.
"""
torch = pytest.importorskip("torch")

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="torch")
def circuit(x):
qml.RX(x, wires=0)
qml.RY(0.543, wires=1)
qml.WireCut(wires=0)
qml.CNOT(wires=[0, 1])
qml.RZ(0.240, wires=0)
qml.RZ(0.133, wires=1)
return qml.expval(qml.PauliZ(wires=[0]))

x = torch.tensor(0.531, requires_grad=True)
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum)

res = cut_circuit(x)
res_expected = circuit(x)
assert np.isclose(res.detach().numpy(), res_expected.detach().numpy())

res.backward()
grad = x.grad.detach().numpy()

x.grad = None
res_expected.backward()
grad_expected = x.grad.detach().numpy()

assert np.isclose(grad, grad_expected)

def test_simple_cut_circuit_tf(self, use_opt_einsum):
"""
Tests the full circuit cutting pipeline returns the correct value and
gradient for a simple circuit using the `cut_circuit` transform with the TF interface.
"""
tf = pytest.importorskip("tensorflow")

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="tf")
def circuit(x):
qml.RX(x, wires=0)
qml.RY(0.543, wires=1)
qml.WireCut(wires=0)
qml.CNOT(wires=[0, 1])
qml.RZ(0.240, wires=0)
qml.RZ(0.133, wires=1)
return qml.expval(qml.PauliZ(wires=[0]))

x = tf.Variable(0.531)
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum)

with tf.GradientTape() as tape:
res = cut_circuit(x)

grad = tape.gradient(res, x)

with tf.GradientTape() as tape:
res_expected = circuit(x)

grad_expected = tape.gradient(res_expected, x)

assert np.isclose(res, res_expected)
assert np.isclose(grad, grad_expected)

def test_simple_cut_circuit_jax(self, use_opt_einsum):
"""
Tests the full circuit cutting pipeline returns the correct value and
gradient for a simple circuit using the `cut_circuit` transform with the Jax interface.
"""
jax = pytest.importorskip("jax")
import jax.numpy as jnp

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="jax")
def circuit(x):
qml.RX(x, wires=0)
qml.RY(0.543, wires=1)
qml.WireCut(wires=0)
qml.CNOT(wires=[0, 1])
qml.RZ(0.240, wires=0)
qml.RZ(0.133, wires=1)
return qml.expval(qml.PauliZ(wires=[0]))

x = jnp.array(0.531)
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum)

res = cut_circuit(x)
res_expected = circuit(x)

grad = jax.grad(cut_circuit)(x)
grad_expected = jax.grad(circuit)(x)

assert np.isclose(res, res_expected)
assert np.isclose(grad, grad_expected)


class TestCutStrategy:
"""Tests for class CutStrategy"""
Expand Down

0 comments on commit 57237b2

Please sign in to comment.