Skip to content

Commit

Permalink
Add utility for extracting classical control information to drawer (#…
Browse files Browse the repository at this point in the history
…4917)

This PR replaces `find_mid_measure_cond_connections` with
`cwire_connections`. It returns very similar information, but is a
slight variant that will be easier to use in the matplotlib drawing of
classical wires.

---------

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
  • Loading branch information
albi3ro and mudit2812 committed Jan 19, 2024
1 parent 96b089b commit 12a8d4a
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 136 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
[(#4803)](https://github.com/PennyLaneAI/pennylane/pull/4803)
[(#4832)](https://github.com/PennyLaneAI/pennylane/pull/4832)
[(#4901)](https://github.com/PennyLaneAI/pennylane/pull/4901)
[(#4917)](https://github.com/PennyLaneAI/pennylane/pull/4917)

<h4>Catalyst is seamlessly integrated with PennyLane ⚗️</h4>

Expand Down
49 changes: 21 additions & 28 deletions pennylane/drawer/tape_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pennylane.measurements import Expectation, Probability, Sample, Variance, State, MidMeasureMP

from .drawable_layers import drawable_layers
from .utils import convert_wire_order, unwrap_controls, find_mid_measure_cond_connections
from .utils import convert_wire_order, unwrap_controls, cwire_connections


@dataclass
Expand All @@ -39,8 +39,8 @@ class _Config:
cur_layer: Optional[int] = None
"""Current layer index that is being updated"""

final_cond_layers: Optional[list] = None
"""List mapping bits to the last layer where they are used"""
cwire_layers: Optional[list] = None
"""A list of layers used (mid measure or conditional) for each classical wire."""

decimals: Optional[int] = None
"""Specifies how to round the parameters of operators"""
Expand Down Expand Up @@ -74,19 +74,19 @@ def _add_cond_grouping_symbols(op, layer_str, config):
max_w = max(mapped_wires)
max_b = max(mapped_bits) + n_wires

ctrl_symbol = "╩" if config.cur_layer != config.final_cond_layers[max(mapped_bits)] else "╝"
layer_str[max_b] = "═" + ctrl_symbol
ctrl_symbol = "╩" if config.cur_layer != config.cwire_layers[max(mapped_bits)][-1] else "╝"
layer_str[max_b] = f"═{ctrl_symbol}"

for w in range(max_w + 1, max(config.wire_map.values()) + 1):
layer_str[w] = "─║"

for b in range(n_wires, max_b):
if b - n_wires in mapped_bits:
intersection = "╣" if config.cur_layer == config.final_cond_layers[b - n_wires] else "╬"
layer_str[b] = "═" + intersection
intersection = "╣" if config.cur_layer == config.cwire_layers[b - n_wires][-1] else "╬"
layer_str[b] = f"═{intersection}"
else:
filler = " " if layer_str[b][-1] == " " else "═"
layer_str[b] = filler + "║"
layer_str[b] = f"{filler}║"

return layer_str

Expand All @@ -107,7 +107,7 @@ def _add_mid_measure_grouping_symbols(op, layer_str, config):

for b in range(n_wires, bit):
filler = " " if layer_str[b][-1] == " " else "═"
layer_str[b] += filler + "║"
layer_str[b] += f"{filler}║"

return layer_str

Expand Down Expand Up @@ -398,12 +398,9 @@ def tape_text(
wire_fillers = ["─", " "]
bit_fillers = ["═", " "]
enders = [True, False] # add "─┤" after all operations
bit_maps = [{}, {}]

bit_maps[0], measurement_layers, final_cond_layers = find_mid_measure_cond_connections(
tape.operations, layers_list[0]
)
n_bits = len(bit_maps[0])
bit_map, cwire_layers, _ = cwire_connections(layers_list[0] + layers_list[1])
n_bits = len(bit_map)

wire_totals = [f"{wire}: " for wire in wire_map]
bit_totals = ["" for _ in range(n_bits)]
Expand All @@ -412,15 +409,15 @@ def tape_text(
wire_totals = [s.rjust(line_length, " ") for s in wire_totals]
bit_totals = [s.rjust(line_length, " ") for s in bit_totals]

for layers, add, w_filler, b_filler, ender, bit_map in zip(
layers_list, add_list, wire_fillers, bit_fillers, enders, bit_maps
for layers, add, w_filler, b_filler, ender in zip(
layers_list, add_list, wire_fillers, bit_fillers, enders
):
# Collect information needed for drawing layers
config = _Config(
wire_map=wire_map,
bit_map=bit_map,
cur_layer=-1,
final_cond_layers=final_cond_layers,
cwire_layers=cwire_layers,
decimals=decimals,
cache=cache,
)
Expand All @@ -429,7 +426,7 @@ def tape_text(
# Add filler before current layer
layer_str = [w_filler] * n_wires + [" "] * n_bits
for b in bit_map.values():
cur_b_filler = b_filler if measurement_layers[b] < i < final_cond_layers[b] else " "
cur_b_filler = b_filler if min(cwire_layers[b]) < i < max(cwire_layers[b]) else " "
layer_str[b + n_wires] = cur_b_filler

config.cur_layer = i
Expand Down Expand Up @@ -462,13 +459,11 @@ def tape_text(
# that are used for conditions correctly
cur_b_filler = (
b_filler
if bit_map[cur_layer_mid_measure] >= b and i < final_cond_layers[b]
if bit_map[cur_layer_mid_measure] >= b and i < cwire_layers[b][-1]
else " "
)
else:
cur_b_filler = (
b_filler if measurement_layers[b] < i < final_cond_layers[b] else " "
)
cur_b_filler = b_filler if cwire_layers[b][0] < i < cwire_layers[b][-1] else " "
layer_str[b + n_wires] = layer_str[b + n_wires].ljust(max_label_len, cur_b_filler)

line_length += max_label_len + 1 # one for the filler character
Expand All @@ -484,7 +479,7 @@ def tape_text(
bit_totals = []
for b in range(n_bits):
cur_b_filler = (
b_filler if measurement_layers[b] < i <= final_cond_layers[b] else " "
b_filler if cwire_layers[b][0] < i <= cwire_layers[b][-1] else " "
)
bit_totals.append(cur_b_filler)

Expand All @@ -495,14 +490,12 @@ def tape_text(
wire_totals = [w_filler.join([t, s]) for t, s in zip(wire_totals, layer_str[:n_wires])]

for j, (bt, s) in enumerate(zip(bit_totals, layer_str[n_wires : n_wires + n_bits])):
cur_b_filler = (
b_filler if measurement_layers[j] < i <= final_cond_layers[j] else " "
)
cur_b_filler = b_filler if cwire_layers[j][0] < i <= cwire_layers[j][-1] else " "
bit_totals[j] = cur_b_filler.join([bt, s])

if ender:
wire_totals = [s + "─┤" for s in wire_totals]
bit_totals = [s + " " for s in bit_totals]
wire_totals = [f"{s}─┤" for s in wire_totals]
bit_totals = [f"{s} " for s in bit_totals]

line_length += 2

Expand Down
109 changes: 48 additions & 61 deletions pennylane/drawer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,70 +103,57 @@ def unwrap_controls(op):
return control_wires, control_values


def find_mid_measure_cond_connections(operations, layers):
"""Collect and return information about connections between mid-circuit measurements
and classical conditions.
This utility function returns three items needed for processing mid-circuit measurements
and classical conditions for drawing:
* A dictionary mapping each mid-circuit measurement to a corresponding bit index.
This map only contains mid-circuit measurements that are used for classical conditioning.
* A list where each index is a bit and the values are the indices of the layers containing
the mid-circuit measurement corresponding to the bits.
* A list where each index is a bit and the values are the indices of the last layers that
use those bits for classical conditions.
def cwire_connections(layers):
"""Extract the information required for classical control wires.
Args:
operations (list[~.Operation]): List of operations on the tape
layers (list[list[~.Operation]]): List of drawable layers containing list of operations
for each layer
layers (List[List[.Operator, .MeasurementProcess]]): the operations and measurements sorted
into layers via ``drawable_layers``. Measurement layers may be appended to operation layers.
Returns:
tuple[dict, list, list]: Data structures needed for correctly drawing classical conditions
as described above.
"""
dict, list, list: map from mid circuit measurement to classical wire, list of list of accessed layers
for each classical wire, and largest wire corresponding to the accessed layers in the list above.
>>> with qml.queuing.AnnotatedQueue() as q:
... m0 = qml.measure(0)
... m1 = qml.measure(1)
... qml.cond(m0 & m1, qml.PauliY)(0)
... qml.cond(m0, qml.S)(3)
>>> tape = qml.tape.QuantumScript.from_queue(q)
>>> layers = drawable_layers(tape)
>>> bit_map, cwire_layers, cwire_wires = cwire_connections(layers)
>>> bit_map
{measure(wires=[0]): 0, measure(wires=[1]): 1}
>>> cwire_layers
[[0, 2, 3], [1, 2]]
>>> cwire_wires
[[0, 0, 3], [1, 0]]
From this information, we can see that the first classical wire is active in layers
0, 2, and 3 while the second classical wire is active in layers 1 and 2. The first "active"
layer will always be the one with the mid circuit measurement.
# Map between mid-circuit measurements and their position on the drawing
# The bit map only contains mid-circuit measurements that are used in
# classical conditions.
"""
bit_map = {}

# Map between classical bit positions and the layer of their corresponding mid-circuit
# measurements.
measurement_layers = []

# Map between classical bit positions and the final layer where the bit is used.
# This is needed to know when to stop drawing a bit line. The bit is the index,
# so each of the two lists must have the same length as the number of bits
final_cond_layers = []

measurements_for_conds = set()
conditional_ops = []
for op in operations:
if isinstance(op, Conditional):
measurements_for_conds.update(op.meas_val.measurements)
conditional_ops.append(op)

if len(measurements_for_conds) > 0:
cond_mid_measures = [op for op in operations if op in measurements_for_conds]
cond_mid_measures.sort(key=operations.index)

bit_map = dict(zip(cond_mid_measures, range(len(cond_mid_measures))))

n_bits = len(bit_map)

# Set lists to correct size
measurement_layers = [None] * n_bits
final_cond_layers = [None] * n_bits

for i, layer in enumerate(layers):
for op in layer:
if isinstance(op, MidMeasureMP) and op in bit_map:
measurement_layers[bit_map[op]] = i

if isinstance(op, Conditional):
for mid_measure in op.meas_val.measurements:
final_cond_layers[bit_map[mid_measure]] = i

return bit_map, measurement_layers, final_cond_layers
for layer in layers:
for op in layer:
if isinstance(op, Conditional):
for m in op.meas_val.measurements:
bit_map[m] = None # place holder till next pass

connected_layers = [[] for _ in bit_map]
connected_wires = [[] for _ in bit_map]
num_cwires = 0
for layer_idx, layer in enumerate(layers):
for op in layer:
if isinstance(op, MidMeasureMP) and op in bit_map:
bit_map[op] = num_cwires
connected_layers[num_cwires].append(layer_idx)
connected_wires[num_cwires].append(op.wires[0])
num_cwires += 1
elif isinstance(op, Conditional):
for m in op.meas_val.measurements:
cwire = bit_map[m]
connected_layers[cwire].append(layer_idx)
connected_wires[cwire].append(max(op.wires))
return bit_map, connected_layers, connected_wires
99 changes: 54 additions & 45 deletions tests/drawer/test_drawer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
default_wire_map,
convert_wire_order,
unwrap_controls,
find_mid_measure_cond_connections,
cwire_connections,
)
from pennylane.wires import Wires

Expand Down Expand Up @@ -164,47 +164,56 @@ def test_multi_defined_control_values(
assert control_values == expected_control_values


class TestFindMidMeasureCondConnections:
"""Tests for the find_mid_measure_cond_connections helper function"""

def test_no_conds(self):
"""Test that the return values are empty if there are no conditional operations."""
operations = [
qml.RX(0.123, 0),
qml.measurements.MidMeasureMP(0),
qml.RX(0.456, 1),
qml.measurements.MidMeasureMP(1),
]
layers = [
[qml.RX(0.123, 0), qml.RX(0.456, 1)],
[qml.measurements.MidMeasureMP(0), qml.measurements.MidMeasureMP(1)],
]

bit_map, measurement_layers, final_cond_layers = find_mid_measure_cond_connections(
operations, layers
)

assert bit_map == {}
assert len(measurement_layers) == len(final_cond_layers) == 0

def test_multi_meas_multi_cond(self):
"""Test that multiple measurements and multiple conditions return the correct
output."""
with qml.queuing.AnnotatedQueue() as q:
m0 = qml.measure(0)
m1 = qml.measure(1)
m2 = qml.measure(2)
qml.cond(m0 & m1, qml.PauliX)(wires=4)
qml.cond(m1 & m2, qml.PauliX)(wires=5)
qml.cond(m2 & m0, qml.PauliX)(wires=6)

operations = q.queue
layers = [[op] for op in operations]

bit_map, measurement_layers, final_cond_layers = find_mid_measure_cond_connections(
operations, layers
)

assert bit_map == {operations[0]: 0, operations[1]: 1, operations[2]: 2}
assert measurement_layers == [0, 1, 2]
assert final_cond_layers == [5, 4, 5]
# pylint: disable=use-implicit-booleaness-not-comparison
class TestCwireConnections:
"""Tests for the cwire_connections helper method."""

def test_null_circuit(self):
"""Test null behavior with an empty circuit."""
cmap, layers, wires = cwire_connections([[]])
assert cmap == {}
assert layers == []
assert wires == []

def test_single_measure(self):
"""Test a single meassurment that does not have a conditional."""
cmap, layers, wires = cwire_connections([qml.measure(0).measurements])
assert cmap == {}
assert layers == []
assert wires == []

def test_single_measure_single_cond(self):
"""Test a case with a single measurement and a single conditional."""
m = qml.measure(0)
cond = qml.ops.Conditional(m, qml.PauliX(0))
layers = [m.measurements, [cond]]
cmap, clayers, wires = cwire_connections(layers)
assert cmap == {m.measurements[0]: 0}
assert clayers == [[0, 1]]
assert wires == [[0, 0]]

def test_multiple_measure_multiple_cond(self):
"""Test a case with multiple measurments and multiple conditionals."""
m0 = qml.measure(0)
m1 = qml.measure(1)
m2_nonused = qml.measure(2)

cond0 = qml.ops.Conditional(m0 + m1, qml.PauliX(1))
cond1 = qml.ops.Conditional(m1, qml.PauliY(2))

layers = [m0.measurements, m1.measurements, [cond0], m2_nonused.measurements, [cond1]]
cmap, clayers, wires = cwire_connections(layers)
assert cmap == {m0.measurements[0]: 0, m1.measurements[0]: 1}
assert clayers == [[0, 2], [1, 2, 4]]
assert wires == [[0, 1], [1, 1, 2]]

def test_measurements_layer(self):
"""Test cwire_connections works if measurement layers are appended at the end."""

m0 = qml.measure(0)
cond0 = qml.ops.Conditional(m0, qml.S(0))
layers = [m0.measurements, [cond0], [qml.expval(qml.PauliX(0))]]
cmap, clayers, wires = cwire_connections(layers)
assert cmap == {m0.measurements[0]: 0}
assert clayers == [[0, 1]]
assert wires == [[0, 0]]
4 changes: 2 additions & 2 deletions tests/drawer/test_tape_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_add_cond_grouping_symbols(self, cond_op, bit_map, mv, cur_layer, args,
wire_map=default_wire_map,
bit_map=bit_map,
cur_layer=cur_layer,
final_cond_layers=[0, 1],
cwire_layers=[[0], [1]],
),
)

Expand Down Expand Up @@ -290,7 +290,7 @@ def test_add_cond_op(self, cond_op, bit_map, mv, args, kwargs, out):
op,
layer_str,
_Config(
wire_map=default_wire_map, bit_map=bit_map, cur_layer=1, final_cond_layers=[0, 1]
wire_map=default_wire_map, bit_map=bit_map, cur_layer=1, cwire_layers=[[0], [1]]
),
)

Expand Down

0 comments on commit 12a8d4a

Please sign in to comment.