diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 09e5b982c4a..b9022bd69a3 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -208,6 +208,9 @@ Circuit fragments that are disconnected from the terminal measurements are now removed. [(#2254)](https://github.com/PennyLaneAI/pennylane/pull/2254) + + `WireCut` operations that do not lead to a disconnection are now being removed. + [(#2260)](https://github.com/PennyLaneAI/pennylane/pull/2260)

Improvements

diff --git a/pennylane/transforms/qcut.py b/pennylane/transforms/qcut.py index 2d21ae1741e..1f1b40736eb 100644 --- a/pennylane/transforms/qcut.py +++ b/pennylane/transforms/qcut.py @@ -233,6 +233,7 @@ def tape_to_graph(tape: QuantumTape) -> MultiDiGraph: return graph +# pylint: disable=too-many-branches def fragment_graph(graph: MultiDiGraph) -> Tuple[Tuple[MultiDiGraph], MultiDiGraph]: """ Fragments a graph into a collection of subgraphs as well as returning @@ -305,7 +306,16 @@ def fragment_graph(graph: MultiDiGraph) -> Tuple[Tuple[MultiDiGraph], MultiDiGra if subgraph.has_node(node2): end_fragment = i - communication_graph.add_edge(start_fragment, end_fragment, pair=(node1, node2)) + if start_fragment != end_fragment: + communication_graph.add_edge(start_fragment, end_fragment, pair=(node1, node2)) + else: + # The MeasureNode and PrepareNode pair live in the same fragment and did not result + # in a disconnection. We can therefore remove these nodes. Note that we do not need + # to worry about adding back an edge between the predecessor to node1 and the successor + # to node2 because our next step is to convert the fragment circuit graphs to tapes, + # a process that does not depend on edge connections in the subgraph. + subgraphs[start_fragment].remove_node(node1) + subgraphs[end_fragment].remove_node(node2) terminal_indices = [i for i, s in enumerate(subgraphs) for n in measure_nodes if s.has_node(n)] diff --git a/tests/transforms/test_qcut.py b/tests/transforms/test_qcut.py index a33c989eec1..782c16aee24 100644 --- a/tests/transforms/test_qcut.py +++ b/tests/transforms/test_qcut.py @@ -21,7 +21,7 @@ from itertools import product import pytest -from networkx import MultiDiGraph +from networkx import MultiDiGraph, number_of_selfloops from scipy.stats import unitary_group import pennylane as qml @@ -900,6 +900,23 @@ def test_communication_graph_persistence(self): assert communication_graph_0.nodes == communication_graph_1.nodes assert communication_graph_0.edges == communication_graph_1.edges + def test_contained_cut(self): + """Tests that fragmentation ignores `MeasureNode` and `PrepareNode` pairs that do not + result in a disconnection""" + with qml.tape.QuantumTape() as tape: + qml.RX(0.4, wires=0) + qml.CNOT(wires=[0, 1]) + qml.WireCut(wires=0) + qml.CNOT(wires=[0, 1]) + qml.RX(0.4, wires=0) + qml.expval(qml.PauliZ(0)) + + g = qcut.tape_to_graph(tape) + qcut.replace_wire_cut_nodes(g) + fragments, communication_graph = qcut.fragment_graph(g) + assert len(fragments) == 1 + assert number_of_selfloops(communication_graph) == 0 + class TestGraphToTape: """Tests that directed multigraphs are correctly converted to tapes""" @@ -2095,7 +2112,7 @@ def circuit(x): res = circuit(x) assert np.allclose(res, np.cos(x)) assert len(spy.call_args[0][0]) == 1 # there should be 2 tensors for wire 0 - assert spy.call_args[0][0][0].shape == (4, 4) + assert spy.call_args[0][0][0].shape == () class TestCutCircuitTransformValidation: