From 077210d6f44412e5a9119bbcf22f91cfc6062586 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Wed, 7 Aug 2024 20:35:20 -0400 Subject: [PATCH] Fix dag visualization with Var wires (#12848) * Fix dag visualization with Var wires This commit fixes the dag visualization for DAGs with classical variables. The Var type was not handled in the attribute callback functions for nodes and edges. This was causing the visualizer to fail if the dag contained these types. This fixes this by adding explict handling for the Var types and using the name attribute of the Var object. * Add release note and test (cherry picked from commit adbe88707cba6ee327e0663563b8dd1713c4dfc6) --- qiskit/visualization/dag_visualization.py | 13 ++++++++++--- .../notes/fix-var-wires-4ebc40e0b19df253.yaml | 8 ++++++++ test/python/visualization/test_dag_drawer.py | 15 ++++++++++++++- 3 files changed, 32 insertions(+), 4 deletions(-) create mode 100644 releasenotes/notes/fix-var-wires-4ebc40e0b19df253.yaml diff --git a/qiskit/visualization/dag_visualization.py b/qiskit/visualization/dag_visualization.py index 73b9c30f6dc8..6229765a6460 100644 --- a/qiskit/visualization/dag_visualization.py +++ b/qiskit/visualization/dag_visualization.py @@ -174,10 +174,13 @@ def node_attr_func(node): label = register_bit_labels.get( node.wire, f"q_{dag.find_bit(node.wire).index}" ) - else: + elif isinstance(node.wire, Clbit): label = register_bit_labels.get( node.wire, f"c_{dag.find_bit(node.wire).index}" ) + else: + label = str(node.wire.name) + n["label"] = label n["color"] = "black" n["style"] = "filled" @@ -187,10 +190,12 @@ def node_attr_func(node): label = register_bit_labels.get( node.wire, f"q[{dag.find_bit(node.wire).index}]" ) - else: + elif isinstance(node.wire, Clbit): label = register_bit_labels.get( node.wire, f"c[{dag.find_bit(node.wire).index}]" ) + else: + label = str(node.wire.name) n["label"] = label n["color"] = "black" n["style"] = "filled" @@ -203,8 +208,10 @@ def edge_attr_func(edge): e = {} if isinstance(edge, Qubit): label = register_bit_labels.get(edge, f"q_{dag.find_bit(edge).index}") - else: + elif isinstance(edge, Clbit): label = register_bit_labels.get(edge, f"c_{dag.find_bit(edge).index}") + else: + label = str(edge.name) e["label"] = label return e diff --git a/releasenotes/notes/fix-var-wires-4ebc40e0b19df253.yaml b/releasenotes/notes/fix-var-wires-4ebc40e0b19df253.yaml new file mode 100644 index 000000000000..7cd1e74806b0 --- /dev/null +++ b/releasenotes/notes/fix-var-wires-4ebc40e0b19df253.yaml @@ -0,0 +1,8 @@ +--- +fixes: + - | + Fixed an issue with :func:`.dag_drawer` and :meth:`.DAGCircuit.draw` + when attempting to visualize a :class:`.DAGCircuit` instance that contained + :class:`.Var` wires. The visualizer would raise an exception trying to + do this which has been fixed so the expected visualization will be + generated. diff --git a/test/python/visualization/test_dag_drawer.py b/test/python/visualization/test_dag_drawer.py index 4b920390e880..d789b1e70682 100644 --- a/test/python/visualization/test_dag_drawer.py +++ b/test/python/visualization/test_dag_drawer.py @@ -16,12 +16,14 @@ import tempfile import unittest -from qiskit.circuit import QuantumRegister, ClassicalRegister, QuantumCircuit, Qubit, Clbit +from qiskit.circuit import QuantumRegister, ClassicalRegister, QuantumCircuit, Qubit, Clbit, Store from qiskit.visualization import dag_drawer from qiskit.exceptions import InvalidFileError from qiskit.visualization import VisualizationError from qiskit.converters import circuit_to_dag, circuit_to_dagdependency from qiskit.utils import optionals as _optionals +from qiskit.dagcircuit import DAGCircuit +from qiskit.circuit.classical import expr, types from .visualization import path_to_diagram_reference, QiskitVisualizationTestCase @@ -108,6 +110,17 @@ def test_dag_drawer_with_dag_dep(self): image = Image.open(tmp_path) self.assertImagesAreEqual(image, image_ref, 0.1) + @unittest.skipUnless(_optionals.HAS_GRAPHVIZ, "Graphviz not installed") + @unittest.skipUnless(_optionals.HAS_PIL, "PIL not installed") + def test_dag_drawer_with_var_wires(self): + """Test visualization works with var nodes.""" + a = expr.Var.new("a", types.Bool()) + dag = DAGCircuit() + dag.add_input_var(a) + dag.apply_operation_back(Store(a, a), (), ()) + image = dag_drawer(dag) + self.assertIsNotNone(image) + if __name__ == "__main__": unittest.main(verbosity=2)