-
Notifications
You must be signed in to change notification settings - Fork 586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Circuit cutting: add mid-circuit measurement integration tests #2234
Changes from 13 commits
73acf20
11c9091
447af24
a88da66
b4314e4
f3e8a56
5b69cb5
f0f3839
7192f3d
0cc78d7
4ac5f7b
6f68b95
89dd999
01e507c
b4a9d66
23e7c63
c7b8532
9bf43a6
57c247d
3d39cb9
fe71c95
9554fe8
55da371
bb99ee3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -361,12 +361,16 @@ def graph_to_tape(graph: MultiDiGraph) -> QuantumTape: | |
wire_map = {w: w for w in wires} | ||
reverse_wire_map = {v: k for k, v in wire_map.items()} | ||
|
||
copy_ops = [copy.copy(op) for _, op in ordered_ops] | ||
copy_ops = [copy.copy(op) for _, op in ordered_ops if not isinstance(op, MeasurementProcess)] | ||
copy_meas = [copy.copy(op) for _, op in ordered_ops if isinstance(op, MeasurementProcess)] | ||
observables = [] | ||
|
||
with QuantumTape() as tape: | ||
for op in copy_ops: | ||
new_wires = [wire_map[w] for w in op.wires] | ||
op._wires = Wires(new_wires) # TODO: find a better way to update operation wires | ||
new_wires = Wires([wire_map[w] for w in op.wires]) | ||
|
||
# TODO: find a better way to update operation wires | ||
op._wires = new_wires | ||
apply(op) | ||
|
||
if isinstance(op, MeasureNode): | ||
|
@@ -380,6 +384,19 @@ def graph_to_tape(graph: MultiDiGraph) -> QuantumTape: | |
wire_map[original_wire] = new_wire | ||
reverse_wire_map[new_wire] = original_wire | ||
|
||
for meas in copy_meas: | ||
obs = meas.obs | ||
obs._wires = Wires([wire_map[w] for w in obs.wires]) | ||
observables.append(obs) | ||
|
||
# We assume that each MeasurementProcess node in the graph contributes to a single | ||
# expectation value of an observable, given by the tensor product over the observables of | ||
# each MeasurementProcess. | ||
if len(observables) > 1: | ||
qml.expval(Tensor(*observables)) | ||
elif len(observables) == 1: | ||
qml.expval(obs) | ||
trbromley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return tape | ||
|
||
|
||
|
@@ -424,7 +441,7 @@ def _get_measurements( | |
|
||
obs = measurement.obs | ||
|
||
return [expval(obs @ g) for g in group] | ||
return [expval(copy.deepcopy(obs) @ g) for g in group] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am very confused by this. the tensor product works in-place 🤔 Perhaps it is a queuing-related quirk. obs = qml.PauliZ(0) @ qml.PauliZ(1)
obs2 = obs @ qml.PauliZ(2)
assert obs == obs2
print(obs) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, so the original But There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also works with |
||
|
||
|
||
def _prep_zero_state(wire): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -987,6 +987,41 @@ def test_multiple_conversions(self): | |
for tape1, tape2 in zip(tapes1, tapes2): | ||
compare_tapes(tape1, tape2) | ||
|
||
def test_identity(self): | ||
"""Tests that the graph_to_tape function correctly performs the inverse of the tape_to_graph | ||
function, including converting a tensor product expectation value into separate nodes in the | ||
graph returned by tape_to_graph, and then combining those nodes again into a single tensor | ||
product in the circuit returned by graph_to_tape""" | ||
|
||
with qml.tape.QuantumTape() as tape: | ||
qml.CNOT(wires=[0, 1]) | ||
qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) | ||
|
||
graph = qcut.tape_to_graph(tape) | ||
tape_out = qcut.graph_to_tape(graph) | ||
|
||
compare_tapes(tape, tape_out) | ||
assert len(tape_out.measurements) == 1 | ||
|
||
def test_change_obs_wires(self): | ||
"""Tests that the graph_to_tape function correctly swaps the wires of observables when | ||
the tape contains mid-circuit measurements""" | ||
|
||
with qml.tape.QuantumTape() as tape: | ||
qml.CNOT(wires=[0, 1]) | ||
qcut.MeasureNode(wires=1) | ||
qcut.PrepareNode(wires=1) | ||
qml.CNOT(wires=[0, 1]) | ||
qml.expval(qml.PauliZ(1)) | ||
|
||
graph = qcut.tape_to_graph(tape) | ||
tape_out = qcut.graph_to_tape(graph) | ||
|
||
m = tape_out.measurements | ||
assert len(m) == 1 | ||
assert m[0].wires == Wires([2]) | ||
assert m[0].obs.name == "PauliZ" | ||
|
||
|
||
class TestGetMeasurements: | ||
"""Tests for the _get_measurements function""" | ||
|
@@ -1802,12 +1837,13 @@ def f(x): | |
assert np.allclose(grad, expected_grad) | ||
|
||
|
||
@pytest.mark.parametrize("use_opt_einsum", [True, False]) | ||
class TestCutCircuitTransform: | ||
""" | ||
Tests for the cut_circuit transform | ||
""" | ||
|
||
def test_simple_cut_circuit(self, mocker): | ||
def test_simple_cut_circuit(self, mocker, use_opt_einsum): | ||
""" | ||
Tests the full circuit cutting pipeline returns the correct value and | ||
gradient for a simple circuit using the `cut_circuit` transform. | ||
|
@@ -1827,7 +1863,139 @@ def circuit(x): | |
|
||
spy = mocker.spy(qcut, "qcut_processing_fn") | ||
x = np.array(0.531, requires_grad=True) | ||
cut_circuit = qcut.cut_circuit(circuit) | ||
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum) | ||
|
||
assert np.isclose(cut_circuit(x), float(circuit(x))) | ||
spy.assert_called_once() | ||
|
||
gradient = qml.grad(circuit)(x) | ||
cut_gradient = qml.grad(cut_circuit)(x) | ||
|
||
assert np.isclose(gradient, cut_gradient) | ||
|
||
def test_simple_cut_circuit_torch(self, use_opt_einsum): | ||
""" | ||
Tests the full circuit cutting pipeline returns the correct value and | ||
gradient for a simple circuit using the `cut_circuit` transform with the torch interface. | ||
""" | ||
torch = pytest.importorskip("torch") | ||
|
||
dev = qml.device("default.qubit", wires=2) | ||
|
||
@qml.qnode(dev, interface="torch") | ||
def circuit(x): | ||
qml.RX(x, wires=0) | ||
qml.RY(0.543, wires=1) | ||
qml.WireCut(wires=0) | ||
qml.CNOT(wires=[0, 1]) | ||
qml.RZ(0.240, wires=0) | ||
qml.RZ(0.133, wires=1) | ||
return qml.expval(qml.PauliZ(wires=[0])) | ||
|
||
x = torch.tensor(0.531, requires_grad=True) | ||
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum) | ||
|
||
res = cut_circuit(x) | ||
res_expected = circuit(x) | ||
assert np.isclose(res.detach().numpy(), res_expected.detach().numpy()) | ||
assert isinstance(res, torch.Tensor) | ||
|
||
res.backward() | ||
grad = x.grad.detach().numpy() | ||
|
||
x.grad = None | ||
res_expected.backward() | ||
grad_expected = x.grad.detach().numpy() | ||
|
||
assert np.isclose(grad, grad_expected) | ||
|
||
def test_simple_cut_circuit_tf(self, use_opt_einsum): | ||
""" | ||
Tests the full circuit cutting pipeline returns the correct value and | ||
gradient for a simple circuit using the `cut_circuit` transform with the TF interface. | ||
""" | ||
tf = pytest.importorskip("tensorflow") | ||
|
||
dev = qml.device("default.qubit", wires=2) | ||
|
||
@qml.qnode(dev, interface="tf") | ||
def circuit(x): | ||
qml.RX(x, wires=0) | ||
qml.RY(0.543, wires=1) | ||
qml.WireCut(wires=0) | ||
qml.CNOT(wires=[0, 1]) | ||
qml.RZ(0.240, wires=0) | ||
qml.RZ(0.133, wires=1) | ||
return qml.expval(qml.PauliZ(wires=[0])) | ||
|
||
x = tf.Variable(0.531) | ||
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum) | ||
|
||
with tf.GradientTape() as tape: | ||
res = cut_circuit(x) | ||
|
||
grad = tape.gradient(res, x) | ||
|
||
with tf.GradientTape() as tape: | ||
res_expected = circuit(x) | ||
|
||
grad_expected = tape.gradient(res_expected, x) | ||
|
||
assert np.isclose(res, res_expected) | ||
assert np.isclose(grad, grad_expected) | ||
|
||
def test_simple_cut_circuit_jax(self, use_opt_einsum): | ||
""" | ||
Tests the full circuit cutting pipeline returns the correct value and | ||
gradient for a simple circuit using the `cut_circuit` transform with the Jax interface and | ||
using JIT. | ||
""" | ||
jax = pytest.importorskip("jax") | ||
import jax.numpy as jnp | ||
|
||
dev = qml.device("default.qubit", wires=2) | ||
|
||
@qml.qnode(dev, interface="jax") | ||
def circuit(x): | ||
qml.RX(x, wires=0) | ||
qml.RY(0.543, wires=1) | ||
qml.WireCut(wires=0) | ||
qml.CNOT(wires=[0, 1]) | ||
qml.RZ(0.240, wires=0) | ||
qml.RZ(0.133, wires=1) | ||
return qml.expval(qml.PauliZ(wires=[0])) | ||
|
||
x = jnp.array(0.531) | ||
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum) | ||
|
||
res = cut_circuit(x) | ||
res_expected = circuit(x) | ||
|
||
grad = jax.grad(cut_circuit)(x) | ||
grad_expected = jax.grad(circuit)(x) | ||
|
||
assert np.isclose(res, res_expected) | ||
assert np.isclose(grad, grad_expected) | ||
|
||
def test_with_mid_circuit_measurement(self, mocker, use_opt_einsum): | ||
"""Tests the full circuit cutting pipeline returns the correct value and gradient for a | ||
circuit that contains mid-circuit measurements, using the `cut_circuit` transform.""" | ||
dev = qml.device("default.qubit", wires=3) | ||
|
||
@qml.qnode(dev) | ||
def circuit(x): | ||
qml.RX(x, wires=0) | ||
qml.CNOT(wires=[0, 1]) | ||
qml.WireCut(wires=1) | ||
qml.RX(np.sin(x) ** 2, wires=1) | ||
qml.CNOT(wires=[1, 2]) | ||
qml.WireCut(wires=1) | ||
qml.CNOT(wires=[0, 1]) | ||
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) | ||
|
||
spy = mocker.spy(qcut, "qcut_processing_fn") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps it would be more appropriate to spy on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Happy to change, but why do you think |
||
x = np.array(0.531, requires_grad=True) | ||
cut_circuit = qcut.cut_circuit(circuit, use_opt_einsum=use_opt_einsum) | ||
|
||
assert np.isclose(cut_circuit(x), float(circuit(x))) | ||
spy.assert_called_once() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Previously we were treating the ops (gates) and measurements together. However, when
isinstance(op, MeasurementProcess)
, then settingop._wires = ...
does not work when theMeasurementProcess
is an expectation value of an observable, because the measurement process uses the contained observable to determineop.wires
.In this updated version, we treat the
MeasurementProcesses
more carefully after applying the gates.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
I'm curious what where the cases that revealed this? Mid circuit measurements and tensor products of obs in measurements?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, previously if your observable (or part of it) was on wire
i
, but wirei
had mid-circuit measurements and was relabelled to wirei'
, then the observable wasn't getting its wire updated from wirei
toi'
.