Skip to content

Commit

Permalink
Fix qml.draw/_mpl to work with deferred measurements circuits that …
Browse files Browse the repository at this point in the history
…use MCMs in measurements (#5610)

**Context:**
`defer_measurements` leaves `MeasurementValue`s in the terminal
measurement processes if they are present. This is not friendly with the
drawer as the classical wires for those measurement values do not exist
when using `defer_measurements`. This PR adds a "hack" fix so that the
drawer/MPL drawer work with this scenario.

**Description of the Change:**
Add `transform_deferred_measurements_tape` util function, which replaces
any `MeasurementValue`s present in the tape after applying
`defer_measurements` with wires. This transform is only applied in
`tape_text` and `tape_mpl`.

**Benefits:**
Drawer works better with MCMs

**Possible Drawbacks:**
Hacky fix, technical debt to implement cleaner fix later

**Related GitHub Issues:**
#5588
  • Loading branch information
mudit2812 authored May 1, 2024
1 parent 7621935 commit 091cd68
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 2 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-0.36.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,10 @@

<h3>Bug fixes 🐛</h3>

* Fixed a bug where `qml.draw` and `qml.draw_mpl` incorrectly raised errors for circuits collecting statistics on mid-circuit measurements
while using `qml.defer_measurements`.
[(#5610)](https://github.com/PennyLaneAI/pennylane/pull/5610)

* Using shot vectors with `param_shift(... broadcast=True)` caused a bug. This combination is no longer supported
and will be added again in the next release.
[(#5612)](https://github.com/PennyLaneAI/pennylane/pull/5612)
Expand Down
9 changes: 8 additions & 1 deletion pennylane/drawer/tape_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
from pennylane.measurements import MidMeasureMP
from .mpldrawer import MPLDrawer
from .drawable_layers import drawable_layers
from .utils import convert_wire_order, unwrap_controls, cwire_connections, default_bit_map
from .utils import (
convert_wire_order,
cwire_connections,
default_bit_map,
transform_deferred_measurements_tape,
unwrap_controls,
)
from .style import _set_style

has_mpl = True
Expand Down Expand Up @@ -216,6 +222,7 @@ def _tape_mpl(tape, wire_order=None, show_all_wires=False, decimals=None, *, fig
fontsize = kwargs.get("fontsize", None)

wire_map = convert_wire_order(tape, wire_order=wire_order, show_all_wires=show_all_wires)
tape = transform_deferred_measurements_tape(tape)
tape = qml.map_wires(tape, wire_map=wire_map)[0][0]
bit_map = default_bit_map(tape)

Expand Down
9 changes: 8 additions & 1 deletion pennylane/drawer/tape_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@
)

from .drawable_layers import drawable_layers
from .utils import convert_wire_order, unwrap_controls, cwire_connections, default_bit_map
from .utils import (
convert_wire_order,
cwire_connections,
default_bit_map,
transform_deferred_measurements_tape,
unwrap_controls,
)


@dataclass
Expand Down Expand Up @@ -425,6 +431,7 @@ def tape_text(
New tape offset: 4
"""
tape = transform_deferred_measurements_tape(tape)
cache = cache or {}
cache.setdefault("tape_offset", 0)
cache.setdefault("matrices", [])
Expand Down
20 changes: 20 additions & 0 deletions pennylane/drawer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
from pennylane.ops import Controlled, Conditional
from pennylane.measurements import MeasurementProcess, MidMeasureMP, MeasurementValue
from pennylane.tape import QuantumScript


def default_wire_map(tape):
Expand Down Expand Up @@ -201,3 +202,22 @@ def cwire_connections(layers, bit_map):
connected_layers[cwire].append(layer_idx)

return connected_layers, connected_wires


def transform_deferred_measurements_tape(tape):
"""Helper function to replace MeasurementValues with wires for tapes using
deferred measurements."""
if not any(isinstance(op, MidMeasureMP) for op in tape.operations) and any(
m.mv is not None for m in tape.measurements
):
new_measurements = []
for m in tape.measurements:
if m.mv is not None:
new_m = type(m)(wires=m.wires)
new_measurements.append(new_m)
else:
new_measurements.append(m)
new_tape = QuantumScript(tape.operations, new_measurements)
return new_tape

return tape
51 changes: 51 additions & 0 deletions tests/transforms/test_defer_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,57 @@ def qfunc():
)
assert qml.draw(transformed_qnode)() == expected

@pytest.mark.parametrize(
"mp, label",
[
(qml.sample, "Sample"),
(qml.probs, "Probs"),
(qml.var, "Var[None]"),
(qml.counts, "Counts"),
(qml.expval, "<None>"),
],
)
def test_drawing_with_mcm_terminal_measure(self, mp, label):
"""Test that drawing a func works correctly when collecting statistics on
mid-circuit measurements."""

def qfunc():
m_0 = qml.measure(0, reset=True)
qml.cond(m_0, qml.RY)(0.312, wires=1)

return mp(op=m_0), qml.expval(qml.Z(1))

dev = qml.device("default.qubit", wires=4)

transformed_qfunc = qml.transforms.defer_measurements(qfunc)
transformed_qnode = qml.QNode(transformed_qfunc, dev)

spaces = " " * len(label)
expval = "<Z>".ljust(len(label))
expected = (
f"0: ─╭●─╭X───────────┤ {spaces}\n"
f"1: ─│──│──╭RY(0.31)─┤ {expval}\n"
f"2: ─╰X─╰●─╰●────────┤ {label}"
)
assert qml.draw(transformed_qnode)() == expected

@pytest.mark.parametrize("mp", [qml.sample, qml.probs, qml.var, qml.counts, qml.expval])
def test_draw_mpl_with_mcm_terminal_measure(self, mp):
"""Test that no error is raised when drawing a circuit which collects
statistics on mid-circuit measurements"""

def qfunc():
m_0 = qml.measure(0, reset=True)
qml.cond(m_0, qml.RY)(0.312, wires=1)

return mp(op=m_0), qml.expval(qml.Z(1))

dev = qml.device("default.qubit", wires=4)

transformed_qfunc = qml.transforms.defer_measurements(qfunc)
transformed_qnode = qml.QNode(transformed_qfunc, dev)
_ = qml.draw_mpl(transformed_qnode)()


def test_custom_wire_labels_allowed_without_reuse():
"""Test that custom wire labels work if no qubits are re-used."""
Expand Down

0 comments on commit 091cd68

Please sign in to comment.