diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 9f314673fc7..0980e4c8bdb 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -51,6 +51,7 @@
[(#4803)](https://github.com/PennyLaneAI/pennylane/pull/4803)
[(#4832)](https://github.com/PennyLaneAI/pennylane/pull/4832)
[(#4901)](https://github.com/PennyLaneAI/pennylane/pull/4901)
+ [(#4917)](https://github.com/PennyLaneAI/pennylane/pull/4917)
Catalyst is seamlessly integrated with PennyLane ⚗️
diff --git a/pennylane/drawer/tape_text.py b/pennylane/drawer/tape_text.py
index a06b650de2f..76afa629fa5 100644
--- a/pennylane/drawer/tape_text.py
+++ b/pennylane/drawer/tape_text.py
@@ -23,7 +23,7 @@
from pennylane.measurements import Expectation, Probability, Sample, Variance, State, MidMeasureMP
from .drawable_layers import drawable_layers
-from .utils import convert_wire_order, unwrap_controls, find_mid_measure_cond_connections
+from .utils import convert_wire_order, unwrap_controls, cwire_connections
@dataclass
@@ -39,8 +39,8 @@ class _Config:
cur_layer: Optional[int] = None
"""Current layer index that is being updated"""
- final_cond_layers: Optional[list] = None
- """List mapping bits to the last layer where they are used"""
+ cwire_layers: Optional[list] = None
+ """A list of layers used (mid measure or conditional) for each classical wire."""
decimals: Optional[int] = None
"""Specifies how to round the parameters of operators"""
@@ -74,19 +74,19 @@ def _add_cond_grouping_symbols(op, layer_str, config):
max_w = max(mapped_wires)
max_b = max(mapped_bits) + n_wires
- ctrl_symbol = "╩" if config.cur_layer != config.final_cond_layers[max(mapped_bits)] else "╝"
- layer_str[max_b] = "═" + ctrl_symbol
+ ctrl_symbol = "╩" if config.cur_layer != config.cwire_layers[max(mapped_bits)][-1] else "╝"
+ layer_str[max_b] = f"═{ctrl_symbol}"
for w in range(max_w + 1, max(config.wire_map.values()) + 1):
layer_str[w] = "─║"
for b in range(n_wires, max_b):
if b - n_wires in mapped_bits:
- intersection = "╣" if config.cur_layer == config.final_cond_layers[b - n_wires] else "╬"
- layer_str[b] = "═" + intersection
+ intersection = "╣" if config.cur_layer == config.cwire_layers[b - n_wires][-1] else "╬"
+ layer_str[b] = f"═{intersection}"
else:
filler = " " if layer_str[b][-1] == " " else "═"
- layer_str[b] = filler + "║"
+ layer_str[b] = f"{filler}║"
return layer_str
@@ -107,7 +107,7 @@ def _add_mid_measure_grouping_symbols(op, layer_str, config):
for b in range(n_wires, bit):
filler = " " if layer_str[b][-1] == " " else "═"
- layer_str[b] += filler + "║"
+ layer_str[b] += f"{filler}║"
return layer_str
@@ -398,12 +398,9 @@ def tape_text(
wire_fillers = ["─", " "]
bit_fillers = ["═", " "]
enders = [True, False] # add "─┤" after all operations
- bit_maps = [{}, {}]
- bit_maps[0], measurement_layers, final_cond_layers = find_mid_measure_cond_connections(
- tape.operations, layers_list[0]
- )
- n_bits = len(bit_maps[0])
+ bit_map, cwire_layers, _ = cwire_connections(layers_list[0] + layers_list[1])
+ n_bits = len(bit_map)
wire_totals = [f"{wire}: " for wire in wire_map]
bit_totals = ["" for _ in range(n_bits)]
@@ -412,15 +409,15 @@ def tape_text(
wire_totals = [s.rjust(line_length, " ") for s in wire_totals]
bit_totals = [s.rjust(line_length, " ") for s in bit_totals]
- for layers, add, w_filler, b_filler, ender, bit_map in zip(
- layers_list, add_list, wire_fillers, bit_fillers, enders, bit_maps
+ for layers, add, w_filler, b_filler, ender in zip(
+ layers_list, add_list, wire_fillers, bit_fillers, enders
):
# Collect information needed for drawing layers
config = _Config(
wire_map=wire_map,
bit_map=bit_map,
cur_layer=-1,
- final_cond_layers=final_cond_layers,
+ cwire_layers=cwire_layers,
decimals=decimals,
cache=cache,
)
@@ -429,7 +426,7 @@ def tape_text(
# Add filler before current layer
layer_str = [w_filler] * n_wires + [" "] * n_bits
for b in bit_map.values():
- cur_b_filler = b_filler if measurement_layers[b] < i < final_cond_layers[b] else " "
+ cur_b_filler = b_filler if min(cwire_layers[b]) < i < max(cwire_layers[b]) else " "
layer_str[b + n_wires] = cur_b_filler
config.cur_layer = i
@@ -462,13 +459,11 @@ def tape_text(
# that are used for conditions correctly
cur_b_filler = (
b_filler
- if bit_map[cur_layer_mid_measure] >= b and i < final_cond_layers[b]
+ if bit_map[cur_layer_mid_measure] >= b and i < cwire_layers[b][-1]
else " "
)
else:
- cur_b_filler = (
- b_filler if measurement_layers[b] < i < final_cond_layers[b] else " "
- )
+ cur_b_filler = b_filler if cwire_layers[b][0] < i < cwire_layers[b][-1] else " "
layer_str[b + n_wires] = layer_str[b + n_wires].ljust(max_label_len, cur_b_filler)
line_length += max_label_len + 1 # one for the filler character
@@ -484,7 +479,7 @@ def tape_text(
bit_totals = []
for b in range(n_bits):
cur_b_filler = (
- b_filler if measurement_layers[b] < i <= final_cond_layers[b] else " "
+ b_filler if cwire_layers[b][0] < i <= cwire_layers[b][-1] else " "
)
bit_totals.append(cur_b_filler)
@@ -495,14 +490,12 @@ def tape_text(
wire_totals = [w_filler.join([t, s]) for t, s in zip(wire_totals, layer_str[:n_wires])]
for j, (bt, s) in enumerate(zip(bit_totals, layer_str[n_wires : n_wires + n_bits])):
- cur_b_filler = (
- b_filler if measurement_layers[j] < i <= final_cond_layers[j] else " "
- )
+ cur_b_filler = b_filler if cwire_layers[j][0] < i <= cwire_layers[j][-1] else " "
bit_totals[j] = cur_b_filler.join([bt, s])
if ender:
- wire_totals = [s + "─┤" for s in wire_totals]
- bit_totals = [s + " " for s in bit_totals]
+ wire_totals = [f"{s}─┤" for s in wire_totals]
+ bit_totals = [f"{s} " for s in bit_totals]
line_length += 2
diff --git a/pennylane/drawer/utils.py b/pennylane/drawer/utils.py
index b9a22c4ac2b..27a6d339976 100644
--- a/pennylane/drawer/utils.py
+++ b/pennylane/drawer/utils.py
@@ -103,70 +103,57 @@ def unwrap_controls(op):
return control_wires, control_values
-def find_mid_measure_cond_connections(operations, layers):
- """Collect and return information about connections between mid-circuit measurements
- and classical conditions.
-
- This utility function returns three items needed for processing mid-circuit measurements
- and classical conditions for drawing:
-
- * A dictionary mapping each mid-circuit measurement to a corresponding bit index.
- This map only contains mid-circuit measurements that are used for classical conditioning.
- * A list where each index is a bit and the values are the indices of the layers containing
- the mid-circuit measurement corresponding to the bits.
- * A list where each index is a bit and the values are the indices of the last layers that
- use those bits for classical conditions.
+def cwire_connections(layers):
+ """Extract the information required for classical control wires.
Args:
- operations (list[~.Operation]): List of operations on the tape
- layers (list[list[~.Operation]]): List of drawable layers containing list of operations
- for each layer
+ layers (List[List[.Operator, .MeasurementProcess]]): the operations and measurements sorted
+ into layers via ``drawable_layers``. Measurement layers may be appended to operation layers.
Returns:
- tuple[dict, list, list]: Data structures needed for correctly drawing classical conditions
- as described above.
- """
+ dict, list, list: map from mid circuit measurement to classical wire, list of list of accessed layers
+ for each classical wire, and largest wire corresponding to the accessed layers in the list above.
+
+ >>> with qml.queuing.AnnotatedQueue() as q:
+ ... m0 = qml.measure(0)
+ ... m1 = qml.measure(1)
+ ... qml.cond(m0 & m1, qml.PauliY)(0)
+ ... qml.cond(m0, qml.S)(3)
+ >>> tape = qml.tape.QuantumScript.from_queue(q)
+ >>> layers = drawable_layers(tape)
+ >>> bit_map, cwire_layers, cwire_wires = cwire_connections(layers)
+ >>> bit_map
+ {measure(wires=[0]): 0, measure(wires=[1]): 1}
+ >>> cwire_layers
+ [[0, 2, 3], [1, 2]]
+ >>> cwire_wires
+ [[0, 0, 3], [1, 0]]
+
+ From this information, we can see that the first classical wire is active in layers
+ 0, 2, and 3 while the second classical wire is active in layers 1 and 2. The first "active"
+ layer will always be the one with the mid circuit measurement.
- # Map between mid-circuit measurements and their position on the drawing
- # The bit map only contains mid-circuit measurements that are used in
- # classical conditions.
+ """
bit_map = {}
-
- # Map between classical bit positions and the layer of their corresponding mid-circuit
- # measurements.
- measurement_layers = []
-
- # Map between classical bit positions and the final layer where the bit is used.
- # This is needed to know when to stop drawing a bit line. The bit is the index,
- # so each of the two lists must have the same length as the number of bits
- final_cond_layers = []
-
- measurements_for_conds = set()
- conditional_ops = []
- for op in operations:
- if isinstance(op, Conditional):
- measurements_for_conds.update(op.meas_val.measurements)
- conditional_ops.append(op)
-
- if len(measurements_for_conds) > 0:
- cond_mid_measures = [op for op in operations if op in measurements_for_conds]
- cond_mid_measures.sort(key=operations.index)
-
- bit_map = dict(zip(cond_mid_measures, range(len(cond_mid_measures))))
-
- n_bits = len(bit_map)
-
- # Set lists to correct size
- measurement_layers = [None] * n_bits
- final_cond_layers = [None] * n_bits
-
- for i, layer in enumerate(layers):
- for op in layer:
- if isinstance(op, MidMeasureMP) and op in bit_map:
- measurement_layers[bit_map[op]] = i
-
- if isinstance(op, Conditional):
- for mid_measure in op.meas_val.measurements:
- final_cond_layers[bit_map[mid_measure]] = i
-
- return bit_map, measurement_layers, final_cond_layers
+ for layer in layers:
+ for op in layer:
+ if isinstance(op, Conditional):
+ for m in op.meas_val.measurements:
+ bit_map[m] = None # place holder till next pass
+
+ connected_layers = [[] for _ in bit_map]
+ connected_wires = [[] for _ in bit_map]
+ num_cwires = 0
+ for layer_idx, layer in enumerate(layers):
+ for op in layer:
+ if isinstance(op, MidMeasureMP) and op in bit_map:
+ bit_map[op] = num_cwires
+ connected_layers[num_cwires].append(layer_idx)
+ connected_wires[num_cwires].append(op.wires[0])
+ num_cwires += 1
+ elif isinstance(op, Conditional):
+ for m in op.meas_val.measurements:
+ cwire = bit_map[m]
+ connected_layers[cwire].append(layer_idx)
+ connected_wires[cwire].append(max(op.wires))
+ return bit_map, connected_layers, connected_wires
diff --git a/tests/drawer/test_drawer_utils.py b/tests/drawer/test_drawer_utils.py
index bdc13dc09c9..369a5af9d0b 100644
--- a/tests/drawer/test_drawer_utils.py
+++ b/tests/drawer/test_drawer_utils.py
@@ -21,7 +21,7 @@
default_wire_map,
convert_wire_order,
unwrap_controls,
- find_mid_measure_cond_connections,
+ cwire_connections,
)
from pennylane.wires import Wires
@@ -164,47 +164,56 @@ def test_multi_defined_control_values(
assert control_values == expected_control_values
-class TestFindMidMeasureCondConnections:
- """Tests for the find_mid_measure_cond_connections helper function"""
-
- def test_no_conds(self):
- """Test that the return values are empty if there are no conditional operations."""
- operations = [
- qml.RX(0.123, 0),
- qml.measurements.MidMeasureMP(0),
- qml.RX(0.456, 1),
- qml.measurements.MidMeasureMP(1),
- ]
- layers = [
- [qml.RX(0.123, 0), qml.RX(0.456, 1)],
- [qml.measurements.MidMeasureMP(0), qml.measurements.MidMeasureMP(1)],
- ]
-
- bit_map, measurement_layers, final_cond_layers = find_mid_measure_cond_connections(
- operations, layers
- )
-
- assert bit_map == {}
- assert len(measurement_layers) == len(final_cond_layers) == 0
-
- def test_multi_meas_multi_cond(self):
- """Test that multiple measurements and multiple conditions return the correct
- output."""
- with qml.queuing.AnnotatedQueue() as q:
- m0 = qml.measure(0)
- m1 = qml.measure(1)
- m2 = qml.measure(2)
- qml.cond(m0 & m1, qml.PauliX)(wires=4)
- qml.cond(m1 & m2, qml.PauliX)(wires=5)
- qml.cond(m2 & m0, qml.PauliX)(wires=6)
-
- operations = q.queue
- layers = [[op] for op in operations]
-
- bit_map, measurement_layers, final_cond_layers = find_mid_measure_cond_connections(
- operations, layers
- )
-
- assert bit_map == {operations[0]: 0, operations[1]: 1, operations[2]: 2}
- assert measurement_layers == [0, 1, 2]
- assert final_cond_layers == [5, 4, 5]
+# pylint: disable=use-implicit-booleaness-not-comparison
+class TestCwireConnections:
+ """Tests for the cwire_connections helper method."""
+
+ def test_null_circuit(self):
+ """Test null behavior with an empty circuit."""
+ cmap, layers, wires = cwire_connections([[]])
+ assert cmap == {}
+ assert layers == []
+ assert wires == []
+
+ def test_single_measure(self):
+ """Test a single meassurment that does not have a conditional."""
+ cmap, layers, wires = cwire_connections([qml.measure(0).measurements])
+ assert cmap == {}
+ assert layers == []
+ assert wires == []
+
+ def test_single_measure_single_cond(self):
+ """Test a case with a single measurement and a single conditional."""
+ m = qml.measure(0)
+ cond = qml.ops.Conditional(m, qml.PauliX(0))
+ layers = [m.measurements, [cond]]
+ cmap, clayers, wires = cwire_connections(layers)
+ assert cmap == {m.measurements[0]: 0}
+ assert clayers == [[0, 1]]
+ assert wires == [[0, 0]]
+
+ def test_multiple_measure_multiple_cond(self):
+ """Test a case with multiple measurments and multiple conditionals."""
+ m0 = qml.measure(0)
+ m1 = qml.measure(1)
+ m2_nonused = qml.measure(2)
+
+ cond0 = qml.ops.Conditional(m0 + m1, qml.PauliX(1))
+ cond1 = qml.ops.Conditional(m1, qml.PauliY(2))
+
+ layers = [m0.measurements, m1.measurements, [cond0], m2_nonused.measurements, [cond1]]
+ cmap, clayers, wires = cwire_connections(layers)
+ assert cmap == {m0.measurements[0]: 0, m1.measurements[0]: 1}
+ assert clayers == [[0, 2], [1, 2, 4]]
+ assert wires == [[0, 1], [1, 1, 2]]
+
+ def test_measurements_layer(self):
+ """Test cwire_connections works if measurement layers are appended at the end."""
+
+ m0 = qml.measure(0)
+ cond0 = qml.ops.Conditional(m0, qml.S(0))
+ layers = [m0.measurements, [cond0], [qml.expval(qml.PauliX(0))]]
+ cmap, clayers, wires = cwire_connections(layers)
+ assert cmap == {m0.measurements[0]: 0}
+ assert clayers == [[0, 1]]
+ assert wires == [[0, 0]]
diff --git a/tests/drawer/test_tape_text.py b/tests/drawer/test_tape_text.py
index 8744ea16df1..627e9a462c4 100644
--- a/tests/drawer/test_tape_text.py
+++ b/tests/drawer/test_tape_text.py
@@ -167,7 +167,7 @@ def test_add_cond_grouping_symbols(self, cond_op, bit_map, mv, cur_layer, args,
wire_map=default_wire_map,
bit_map=bit_map,
cur_layer=cur_layer,
- final_cond_layers=[0, 1],
+ cwire_layers=[[0], [1]],
),
)
@@ -290,7 +290,7 @@ def test_add_cond_op(self, cond_op, bit_map, mv, args, kwargs, out):
op,
layer_str,
_Config(
- wire_map=default_wire_map, bit_map=bit_map, cur_layer=1, final_cond_layers=[0, 1]
+ wire_map=default_wire_map, bit_map=bit_map, cur_layer=1, cwire_layers=[[0], [1]]
),
)