diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 09e5b982c4a..b9022bd69a3 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -208,6 +208,9 @@
Circuit fragments that are disconnected from the terminal measurements are now removed.
[(#2254)](https://github.com/PennyLaneAI/pennylane/pull/2254)
+
+ `WireCut` operations that do not lead to a disconnection are now being removed.
+ [(#2260)](https://github.com/PennyLaneAI/pennylane/pull/2260)
Improvements
diff --git a/pennylane/transforms/qcut.py b/pennylane/transforms/qcut.py
index 2d21ae1741e..1f1b40736eb 100644
--- a/pennylane/transforms/qcut.py
+++ b/pennylane/transforms/qcut.py
@@ -233,6 +233,7 @@ def tape_to_graph(tape: QuantumTape) -> MultiDiGraph:
return graph
+# pylint: disable=too-many-branches
def fragment_graph(graph: MultiDiGraph) -> Tuple[Tuple[MultiDiGraph], MultiDiGraph]:
"""
Fragments a graph into a collection of subgraphs as well as returning
@@ -305,7 +306,16 @@ def fragment_graph(graph: MultiDiGraph) -> Tuple[Tuple[MultiDiGraph], MultiDiGra
if subgraph.has_node(node2):
end_fragment = i
- communication_graph.add_edge(start_fragment, end_fragment, pair=(node1, node2))
+ if start_fragment != end_fragment:
+ communication_graph.add_edge(start_fragment, end_fragment, pair=(node1, node2))
+ else:
+ # The MeasureNode and PrepareNode pair live in the same fragment and did not result
+ # in a disconnection. We can therefore remove these nodes. Note that we do not need
+ # to worry about adding back an edge between the predecessor to node1 and the successor
+ # to node2 because our next step is to convert the fragment circuit graphs to tapes,
+ # a process that does not depend on edge connections in the subgraph.
+ subgraphs[start_fragment].remove_node(node1)
+ subgraphs[end_fragment].remove_node(node2)
terminal_indices = [i for i, s in enumerate(subgraphs) for n in measure_nodes if s.has_node(n)]
diff --git a/tests/transforms/test_qcut.py b/tests/transforms/test_qcut.py
index a33c989eec1..782c16aee24 100644
--- a/tests/transforms/test_qcut.py
+++ b/tests/transforms/test_qcut.py
@@ -21,7 +21,7 @@
from itertools import product
import pytest
-from networkx import MultiDiGraph
+from networkx import MultiDiGraph, number_of_selfloops
from scipy.stats import unitary_group
import pennylane as qml
@@ -900,6 +900,23 @@ def test_communication_graph_persistence(self):
assert communication_graph_0.nodes == communication_graph_1.nodes
assert communication_graph_0.edges == communication_graph_1.edges
+ def test_contained_cut(self):
+ """Tests that fragmentation ignores `MeasureNode` and `PrepareNode` pairs that do not
+ result in a disconnection"""
+ with qml.tape.QuantumTape() as tape:
+ qml.RX(0.4, wires=0)
+ qml.CNOT(wires=[0, 1])
+ qml.WireCut(wires=0)
+ qml.CNOT(wires=[0, 1])
+ qml.RX(0.4, wires=0)
+ qml.expval(qml.PauliZ(0))
+
+ g = qcut.tape_to_graph(tape)
+ qcut.replace_wire_cut_nodes(g)
+ fragments, communication_graph = qcut.fragment_graph(g)
+ assert len(fragments) == 1
+ assert number_of_selfloops(communication_graph) == 0
+
class TestGraphToTape:
"""Tests that directed multigraphs are correctly converted to tapes"""
@@ -2095,7 +2112,7 @@ def circuit(x):
res = circuit(x)
assert np.allclose(res, np.cos(x))
assert len(spy.call_args[0][0]) == 1 # there should be 2 tensors for wire 0
- assert spy.call_args[0][0][0].shape == (4, 4)
+ assert spy.call_args[0][0][0].shape == ()
class TestCutCircuitTransformValidation: