diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 9f314673fc7..0980e4c8bdb 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -51,6 +51,7 @@ [(#4803)](https://github.com/PennyLaneAI/pennylane/pull/4803) [(#4832)](https://github.com/PennyLaneAI/pennylane/pull/4832) [(#4901)](https://github.com/PennyLaneAI/pennylane/pull/4901) + [(#4917)](https://github.com/PennyLaneAI/pennylane/pull/4917)

Catalyst is seamlessly integrated with PennyLane ⚗️

diff --git a/pennylane/drawer/tape_text.py b/pennylane/drawer/tape_text.py index a06b650de2f..76afa629fa5 100644 --- a/pennylane/drawer/tape_text.py +++ b/pennylane/drawer/tape_text.py @@ -23,7 +23,7 @@ from pennylane.measurements import Expectation, Probability, Sample, Variance, State, MidMeasureMP from .drawable_layers import drawable_layers -from .utils import convert_wire_order, unwrap_controls, find_mid_measure_cond_connections +from .utils import convert_wire_order, unwrap_controls, cwire_connections @dataclass @@ -39,8 +39,8 @@ class _Config: cur_layer: Optional[int] = None """Current layer index that is being updated""" - final_cond_layers: Optional[list] = None - """List mapping bits to the last layer where they are used""" + cwire_layers: Optional[list] = None + """A list of layers used (mid measure or conditional) for each classical wire.""" decimals: Optional[int] = None """Specifies how to round the parameters of operators""" @@ -74,19 +74,19 @@ def _add_cond_grouping_symbols(op, layer_str, config): max_w = max(mapped_wires) max_b = max(mapped_bits) + n_wires - ctrl_symbol = "╩" if config.cur_layer != config.final_cond_layers[max(mapped_bits)] else "╝" - layer_str[max_b] = "═" + ctrl_symbol + ctrl_symbol = "╩" if config.cur_layer != config.cwire_layers[max(mapped_bits)][-1] else "╝" + layer_str[max_b] = f"═{ctrl_symbol}" for w in range(max_w + 1, max(config.wire_map.values()) + 1): layer_str[w] = "─║" for b in range(n_wires, max_b): if b - n_wires in mapped_bits: - intersection = "╣" if config.cur_layer == config.final_cond_layers[b - n_wires] else "╬" - layer_str[b] = "═" + intersection + intersection = "╣" if config.cur_layer == config.cwire_layers[b - n_wires][-1] else "╬" + layer_str[b] = f"═{intersection}" else: filler = " " if layer_str[b][-1] == " " else "═" - layer_str[b] = filler + "║" + layer_str[b] = f"{filler}║" return layer_str @@ -107,7 +107,7 @@ def _add_mid_measure_grouping_symbols(op, layer_str, config): for b in range(n_wires, bit): filler = " " if layer_str[b][-1] == " " else "═" - layer_str[b] += filler + "║" + layer_str[b] += f"{filler}║" return layer_str @@ -398,12 +398,9 @@ def tape_text( wire_fillers = ["─", " "] bit_fillers = ["═", " "] enders = [True, False] # add "─┤" after all operations - bit_maps = [{}, {}] - bit_maps[0], measurement_layers, final_cond_layers = find_mid_measure_cond_connections( - tape.operations, layers_list[0] - ) - n_bits = len(bit_maps[0]) + bit_map, cwire_layers, _ = cwire_connections(layers_list[0] + layers_list[1]) + n_bits = len(bit_map) wire_totals = [f"{wire}: " for wire in wire_map] bit_totals = ["" for _ in range(n_bits)] @@ -412,15 +409,15 @@ def tape_text( wire_totals = [s.rjust(line_length, " ") for s in wire_totals] bit_totals = [s.rjust(line_length, " ") for s in bit_totals] - for layers, add, w_filler, b_filler, ender, bit_map in zip( - layers_list, add_list, wire_fillers, bit_fillers, enders, bit_maps + for layers, add, w_filler, b_filler, ender in zip( + layers_list, add_list, wire_fillers, bit_fillers, enders ): # Collect information needed for drawing layers config = _Config( wire_map=wire_map, bit_map=bit_map, cur_layer=-1, - final_cond_layers=final_cond_layers, + cwire_layers=cwire_layers, decimals=decimals, cache=cache, ) @@ -429,7 +426,7 @@ def tape_text( # Add filler before current layer layer_str = [w_filler] * n_wires + [" "] * n_bits for b in bit_map.values(): - cur_b_filler = b_filler if measurement_layers[b] < i < final_cond_layers[b] else " " + cur_b_filler = b_filler if min(cwire_layers[b]) < i < max(cwire_layers[b]) else " " layer_str[b + n_wires] = cur_b_filler config.cur_layer = i @@ -462,13 +459,11 @@ def tape_text( # that are used for conditions correctly cur_b_filler = ( b_filler - if bit_map[cur_layer_mid_measure] >= b and i < final_cond_layers[b] + if bit_map[cur_layer_mid_measure] >= b and i < cwire_layers[b][-1] else " " ) else: - cur_b_filler = ( - b_filler if measurement_layers[b] < i < final_cond_layers[b] else " " - ) + cur_b_filler = b_filler if cwire_layers[b][0] < i < cwire_layers[b][-1] else " " layer_str[b + n_wires] = layer_str[b + n_wires].ljust(max_label_len, cur_b_filler) line_length += max_label_len + 1 # one for the filler character @@ -484,7 +479,7 @@ def tape_text( bit_totals = [] for b in range(n_bits): cur_b_filler = ( - b_filler if measurement_layers[b] < i <= final_cond_layers[b] else " " + b_filler if cwire_layers[b][0] < i <= cwire_layers[b][-1] else " " ) bit_totals.append(cur_b_filler) @@ -495,14 +490,12 @@ def tape_text( wire_totals = [w_filler.join([t, s]) for t, s in zip(wire_totals, layer_str[:n_wires])] for j, (bt, s) in enumerate(zip(bit_totals, layer_str[n_wires : n_wires + n_bits])): - cur_b_filler = ( - b_filler if measurement_layers[j] < i <= final_cond_layers[j] else " " - ) + cur_b_filler = b_filler if cwire_layers[j][0] < i <= cwire_layers[j][-1] else " " bit_totals[j] = cur_b_filler.join([bt, s]) if ender: - wire_totals = [s + "─┤" for s in wire_totals] - bit_totals = [s + " " for s in bit_totals] + wire_totals = [f"{s}─┤" for s in wire_totals] + bit_totals = [f"{s} " for s in bit_totals] line_length += 2 diff --git a/pennylane/drawer/utils.py b/pennylane/drawer/utils.py index b9a22c4ac2b..27a6d339976 100644 --- a/pennylane/drawer/utils.py +++ b/pennylane/drawer/utils.py @@ -103,70 +103,57 @@ def unwrap_controls(op): return control_wires, control_values -def find_mid_measure_cond_connections(operations, layers): - """Collect and return information about connections between mid-circuit measurements - and classical conditions. - - This utility function returns three items needed for processing mid-circuit measurements - and classical conditions for drawing: - - * A dictionary mapping each mid-circuit measurement to a corresponding bit index. - This map only contains mid-circuit measurements that are used for classical conditioning. - * A list where each index is a bit and the values are the indices of the layers containing - the mid-circuit measurement corresponding to the bits. - * A list where each index is a bit and the values are the indices of the last layers that - use those bits for classical conditions. +def cwire_connections(layers): + """Extract the information required for classical control wires. Args: - operations (list[~.Operation]): List of operations on the tape - layers (list[list[~.Operation]]): List of drawable layers containing list of operations - for each layer + layers (List[List[.Operator, .MeasurementProcess]]): the operations and measurements sorted + into layers via ``drawable_layers``. Measurement layers may be appended to operation layers. Returns: - tuple[dict, list, list]: Data structures needed for correctly drawing classical conditions - as described above. - """ + dict, list, list: map from mid circuit measurement to classical wire, list of list of accessed layers + for each classical wire, and largest wire corresponding to the accessed layers in the list above. + + >>> with qml.queuing.AnnotatedQueue() as q: + ... m0 = qml.measure(0) + ... m1 = qml.measure(1) + ... qml.cond(m0 & m1, qml.PauliY)(0) + ... qml.cond(m0, qml.S)(3) + >>> tape = qml.tape.QuantumScript.from_queue(q) + >>> layers = drawable_layers(tape) + >>> bit_map, cwire_layers, cwire_wires = cwire_connections(layers) + >>> bit_map + {measure(wires=[0]): 0, measure(wires=[1]): 1} + >>> cwire_layers + [[0, 2, 3], [1, 2]] + >>> cwire_wires + [[0, 0, 3], [1, 0]] + + From this information, we can see that the first classical wire is active in layers + 0, 2, and 3 while the second classical wire is active in layers 1 and 2. The first "active" + layer will always be the one with the mid circuit measurement. - # Map between mid-circuit measurements and their position on the drawing - # The bit map only contains mid-circuit measurements that are used in - # classical conditions. + """ bit_map = {} - - # Map between classical bit positions and the layer of their corresponding mid-circuit - # measurements. - measurement_layers = [] - - # Map between classical bit positions and the final layer where the bit is used. - # This is needed to know when to stop drawing a bit line. The bit is the index, - # so each of the two lists must have the same length as the number of bits - final_cond_layers = [] - - measurements_for_conds = set() - conditional_ops = [] - for op in operations: - if isinstance(op, Conditional): - measurements_for_conds.update(op.meas_val.measurements) - conditional_ops.append(op) - - if len(measurements_for_conds) > 0: - cond_mid_measures = [op for op in operations if op in measurements_for_conds] - cond_mid_measures.sort(key=operations.index) - - bit_map = dict(zip(cond_mid_measures, range(len(cond_mid_measures)))) - - n_bits = len(bit_map) - - # Set lists to correct size - measurement_layers = [None] * n_bits - final_cond_layers = [None] * n_bits - - for i, layer in enumerate(layers): - for op in layer: - if isinstance(op, MidMeasureMP) and op in bit_map: - measurement_layers[bit_map[op]] = i - - if isinstance(op, Conditional): - for mid_measure in op.meas_val.measurements: - final_cond_layers[bit_map[mid_measure]] = i - - return bit_map, measurement_layers, final_cond_layers + for layer in layers: + for op in layer: + if isinstance(op, Conditional): + for m in op.meas_val.measurements: + bit_map[m] = None # place holder till next pass + + connected_layers = [[] for _ in bit_map] + connected_wires = [[] for _ in bit_map] + num_cwires = 0 + for layer_idx, layer in enumerate(layers): + for op in layer: + if isinstance(op, MidMeasureMP) and op in bit_map: + bit_map[op] = num_cwires + connected_layers[num_cwires].append(layer_idx) + connected_wires[num_cwires].append(op.wires[0]) + num_cwires += 1 + elif isinstance(op, Conditional): + for m in op.meas_val.measurements: + cwire = bit_map[m] + connected_layers[cwire].append(layer_idx) + connected_wires[cwire].append(max(op.wires)) + return bit_map, connected_layers, connected_wires diff --git a/tests/drawer/test_drawer_utils.py b/tests/drawer/test_drawer_utils.py index bdc13dc09c9..369a5af9d0b 100644 --- a/tests/drawer/test_drawer_utils.py +++ b/tests/drawer/test_drawer_utils.py @@ -21,7 +21,7 @@ default_wire_map, convert_wire_order, unwrap_controls, - find_mid_measure_cond_connections, + cwire_connections, ) from pennylane.wires import Wires @@ -164,47 +164,56 @@ def test_multi_defined_control_values( assert control_values == expected_control_values -class TestFindMidMeasureCondConnections: - """Tests for the find_mid_measure_cond_connections helper function""" - - def test_no_conds(self): - """Test that the return values are empty if there are no conditional operations.""" - operations = [ - qml.RX(0.123, 0), - qml.measurements.MidMeasureMP(0), - qml.RX(0.456, 1), - qml.measurements.MidMeasureMP(1), - ] - layers = [ - [qml.RX(0.123, 0), qml.RX(0.456, 1)], - [qml.measurements.MidMeasureMP(0), qml.measurements.MidMeasureMP(1)], - ] - - bit_map, measurement_layers, final_cond_layers = find_mid_measure_cond_connections( - operations, layers - ) - - assert bit_map == {} - assert len(measurement_layers) == len(final_cond_layers) == 0 - - def test_multi_meas_multi_cond(self): - """Test that multiple measurements and multiple conditions return the correct - output.""" - with qml.queuing.AnnotatedQueue() as q: - m0 = qml.measure(0) - m1 = qml.measure(1) - m2 = qml.measure(2) - qml.cond(m0 & m1, qml.PauliX)(wires=4) - qml.cond(m1 & m2, qml.PauliX)(wires=5) - qml.cond(m2 & m0, qml.PauliX)(wires=6) - - operations = q.queue - layers = [[op] for op in operations] - - bit_map, measurement_layers, final_cond_layers = find_mid_measure_cond_connections( - operations, layers - ) - - assert bit_map == {operations[0]: 0, operations[1]: 1, operations[2]: 2} - assert measurement_layers == [0, 1, 2] - assert final_cond_layers == [5, 4, 5] +# pylint: disable=use-implicit-booleaness-not-comparison +class TestCwireConnections: + """Tests for the cwire_connections helper method.""" + + def test_null_circuit(self): + """Test null behavior with an empty circuit.""" + cmap, layers, wires = cwire_connections([[]]) + assert cmap == {} + assert layers == [] + assert wires == [] + + def test_single_measure(self): + """Test a single meassurment that does not have a conditional.""" + cmap, layers, wires = cwire_connections([qml.measure(0).measurements]) + assert cmap == {} + assert layers == [] + assert wires == [] + + def test_single_measure_single_cond(self): + """Test a case with a single measurement and a single conditional.""" + m = qml.measure(0) + cond = qml.ops.Conditional(m, qml.PauliX(0)) + layers = [m.measurements, [cond]] + cmap, clayers, wires = cwire_connections(layers) + assert cmap == {m.measurements[0]: 0} + assert clayers == [[0, 1]] + assert wires == [[0, 0]] + + def test_multiple_measure_multiple_cond(self): + """Test a case with multiple measurments and multiple conditionals.""" + m0 = qml.measure(0) + m1 = qml.measure(1) + m2_nonused = qml.measure(2) + + cond0 = qml.ops.Conditional(m0 + m1, qml.PauliX(1)) + cond1 = qml.ops.Conditional(m1, qml.PauliY(2)) + + layers = [m0.measurements, m1.measurements, [cond0], m2_nonused.measurements, [cond1]] + cmap, clayers, wires = cwire_connections(layers) + assert cmap == {m0.measurements[0]: 0, m1.measurements[0]: 1} + assert clayers == [[0, 2], [1, 2, 4]] + assert wires == [[0, 1], [1, 1, 2]] + + def test_measurements_layer(self): + """Test cwire_connections works if measurement layers are appended at the end.""" + + m0 = qml.measure(0) + cond0 = qml.ops.Conditional(m0, qml.S(0)) + layers = [m0.measurements, [cond0], [qml.expval(qml.PauliX(0))]] + cmap, clayers, wires = cwire_connections(layers) + assert cmap == {m0.measurements[0]: 0} + assert clayers == [[0, 1]] + assert wires == [[0, 0]] diff --git a/tests/drawer/test_tape_text.py b/tests/drawer/test_tape_text.py index 8744ea16df1..627e9a462c4 100644 --- a/tests/drawer/test_tape_text.py +++ b/tests/drawer/test_tape_text.py @@ -167,7 +167,7 @@ def test_add_cond_grouping_symbols(self, cond_op, bit_map, mv, cur_layer, args, wire_map=default_wire_map, bit_map=bit_map, cur_layer=cur_layer, - final_cond_layers=[0, 1], + cwire_layers=[[0], [1]], ), ) @@ -290,7 +290,7 @@ def test_add_cond_op(self, cond_op, bit_map, mv, args, kwargs, out): op, layer_str, _Config( - wire_map=default_wire_map, bit_map=bit_map, cur_layer=1, final_cond_layers=[0, 1] + wire_map=default_wire_map, bit_map=bit_map, cur_layer=1, cwire_layers=[[0], [1]] ), )