Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve efficiency of vf2 pass search with free nodes #9148

Merged
merged 6 commits into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion qiskit/transpiler/passes/layout/vf2_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def run(self, dag):
if result is None:
self.property_set["VF2Layout_stop_reason"] = VF2LayoutStopReason.MORE_THAN_2Q
return
im_graph, im_graph_node_map, reverse_im_graph_node_map = result
im_graph, im_graph_node_map, reverse_im_graph_node_map, free_nodes = result
cm_graph, cm_nodes = vf2_utils.shuffle_coupling_graph(
self.coupling_map, self.seed, self.strict_direction
)
Expand Down Expand Up @@ -233,6 +233,13 @@ def mapping_to_layout(layout_mapping):
if chosen_layout is None:
stop_reason = VF2LayoutStopReason.NO_SOLUTION_FOUND
else:
chosen_layout = vf2_utils.map_free_qubits(
free_nodes,
chosen_layout,
cm_graph.num_nodes(),
reverse_im_graph_node_map,
self.avg_error_map,
)
self.property_set["layout"] = chosen_layout
for reg in dag.qregs.values():
self.property_set["layout"].add_register(reg)
Expand Down
9 changes: 8 additions & 1 deletion qiskit/transpiler/passes/layout/vf2_post_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def run(self, dag):
if result is None:
self.property_set["VF2PostLayout_stop_reason"] = VF2PostLayoutStopReason.MORE_THAN_2Q
return
im_graph, im_graph_node_map, reverse_im_graph_node_map = result
im_graph, im_graph_node_map, reverse_im_graph_node_map, free_nodes = result

if self.target is not None:
# If qargs is None then target is global and ideal so no
Expand Down Expand Up @@ -322,6 +322,13 @@ def run(self, dag):
if chosen_layout is None:
stop_reason = VF2PostLayoutStopReason.NO_SOLUTION_FOUND
else:
chosen_layout = vf2_utils.map_free_qubits(
free_nodes,
chosen_layout,
cm_graph.num_nodes(),
reverse_im_graph_node_map,
self.avg_error_map,
)
existing_layout = self.property_set["layout"]
# If any ancillas in initial layout map them back to the final layout output
if existing_layout is not None and len(existing_layout) > len(chosen_layout):
Expand Down
52 changes: 47 additions & 5 deletions qiskit/transpiler/passes/layout/vf2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import random

import numpy as np
from rustworkx import PyDiGraph, PyGraph
from rustworkx import PyDiGraph, PyGraph, connected_components

from qiskit.circuit import ControlFlowOp, ForLoopOp
from qiskit.converters import circuit_to_dag
Expand Down Expand Up @@ -79,7 +79,20 @@ def _visit(dag, weight, wire_map):
_visit(dag, 1, {bit: bit for bit in dag.qubits})
except MultiQEncountered:
return None
return im_graph, im_graph_node_map, reverse_im_graph_node_map
# Remove components with no 2q interactions from interaction graph
# these will be evaluated separately independently of scoring isomorphic
# mappings. This is not done for strict direction because for post layout
# we need to factor in local operation constraints when evaluating a graph
free_nodes = {}
if not strict_direction:
conn_comp = connected_components(im_graph)
for comp in conn_comp:
if len(comp) == 1:
index = comp.pop()
free_nodes[index] = im_graph[index]
im_graph.remove_node(index)

return im_graph, im_graph_node_map, reverse_im_graph_node_map, free_nodes


def score_layout(
Expand All @@ -98,8 +111,9 @@ def score_layout(
size = 0
nlayout = NLayout(layout_mapping, size + 1, size + 1)
bit_list = np.zeros(len(im_graph), dtype=np.int32)
for node_index in bit_map.values():
bit_list[node_index] = sum(im_graph[node_index].values())
if strict_direction:
for node_index in bit_map.values():
bit_list[node_index] = sum(im_graph[node_index].values())
Comment on lines -101 to +116
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This (and the bit in build_average_error_map) look like a slightly separate bugfix. Is that right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, although it's semi related I think (it's been a while since I looked at this patch). The strict direction case doesn't use the common base between the passes in the same way (since vf2post has to do a lot more checking to ensure direcitonality constraints and target basis are preserved). Splitting out the 1q piece from things broke things IIRC (only vaguely remember) so we needed the if condition to branch the logic on that otherwise it failed.

edge_list = {
(edge[0], edge[1]): sum(edge[2].values()) for edge in im_graph.edge_index_map().values()
}
Expand Down Expand Up @@ -162,7 +176,12 @@ def build_average_error_map(target, properties, coupling_map):
continue
avg_map.add_error(qargs, statistics.mean(v))
built = True
elif coupling_map is not None:
# if there are no error rates in the target we should fallback to using the degree heuristic
# used for a coupling map. To do this we can build the coupling map from the target before
# running the fallback heuristic
if not built and target is not None and coupling_map is None:
coupling_map = target.build_coupling_map()
if not built and coupling_map is not None:
for qubit in range(num_qubits):
avg_map.add_error(
(qubit, qubit),
Expand Down Expand Up @@ -194,3 +213,26 @@ def shuffle_coupling_graph(coupling_map, seed, strict_direction=True):
cm_nodes = [k for k, v in sorted(enumerate(cm_nodes), key=lambda item: item[1])]
cm_graph = shuffled_cm_graph
return cm_graph, cm_nodes


def map_free_qubits(
free_nodes, partial_layout, num_physical_qubits, reverse_bit_map, avg_error_map
):
"""Add any free nodes to a layout."""
if not free_nodes:
return partial_layout
if avg_error_map is not None:
free_qubits = sorted(
set(range(num_physical_qubits)) - partial_layout.get_physical_bits().keys(),
key=lambda bit: avg_error_map.get((bit, bit), 1.0),
)
# If no error map is available this means there is no scoring heuristic available for this
# backend and we can just randomly pick a free qubit
else:
free_qubits = list(
set(range(num_physical_qubits)) - partial_layout.get_physical_bits().keys()
)
for im_index in sorted(free_nodes, key=lambda x: sum(free_nodes[x].values())):
selected_qubit = free_qubits.pop(0)
partial_layout.add(reverse_bit_map[im_index], selected_qubit)
return partial_layout
10 changes: 10 additions & 0 deletions src/error_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ impl ErrorMap {
fn __contains__(&self, key: [usize; 2]) -> PyResult<bool> {
Ok(self.error_map.contains_key(&key))
}

fn get(&self, py: Python, key: [usize; 2], default: Option<PyObject>) -> PyObject {
match self.error_map.get(&key).copied() {
Some(val) => val.to_object(py),
None => match default {
Some(val) => val,
None => py.None(),
},
}
}
}

#[pymodule]
Expand Down
28 changes: 15 additions & 13 deletions src/vf2_layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,21 @@ pub fn score_layout(
} else {
edge_list.par_iter().filter_map(edge_filter_map).product()
};
fidelity *= if bit_list.len() < PARALLEL_THRESHOLD || !run_in_parallel {
bit_counts
.iter()
.enumerate()
.filter_map(bit_filter_map)
.product::<f64>()
} else {
bit_counts
.par_iter()
.enumerate()
.filter_map(bit_filter_map)
.product()
};
if strict_direction {
fidelity *= if bit_list.len() < PARALLEL_THRESHOLD || !run_in_parallel {
bit_counts
.iter()
.enumerate()
.filter_map(bit_filter_map)
.product::<f64>()
} else {
bit_counts
.par_iter()
.enumerate()
.filter_map(bit_filter_map)
.product()
};
}
Ok(1. - fidelity)
}

Expand Down
6 changes: 4 additions & 2 deletions test/python/transpiler/test_vf2_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ def test_max_trials_exceeded(self):
qr = QuantumRegister(2)
qc = QuantumCircuit(qr)
qc.x(qr)
qc.cx(0, 1)
qc.measure_all()
cmap = CouplingMap(backend.configuration().coupling_map)
properties = backend.properties()
Expand All @@ -599,6 +600,7 @@ def test_time_limit_exceeded(self):
qr = QuantumRegister(2)
qc = QuantumCircuit(qr)
qc.x(qr)
qc.cx(0, 1)
qc.measure_all()
cmap = CouplingMap(backend.configuration().coupling_map)
properties = backend.properties()
Expand All @@ -620,7 +622,7 @@ def test_reasonable_limits_for_simple_layouts(self):
"""Test that the default trials is set to a reasonable number."""
backend = FakeManhattan()
qc = QuantumCircuit(5)
qc.h(2)
qc.cx(2, 3)
qc.cx(0, 1)
cmap = CouplingMap(backend.configuration().coupling_map)
properties = backend.properties()
Expand All @@ -633,7 +635,7 @@ def test_reasonable_limits_for_simple_layouts(self):
"DEBUG:qiskit.transpiler.passes.layout.vf2_layout:Trial 159 is >= configured max trials 159",
cm.output,
)
self.assertEqual(set(property_set["layout"].get_physical_bits()), {49, 40, 58, 0, 1})
self.assertEqual(set(property_set["layout"].get_physical_bits()), {49, 40, 33, 0, 34})

def test_no_limits_with_negative(self):
"""Test that we're not enforcing a trial limit if set to negative."""
Expand Down
19 changes: 0 additions & 19 deletions test/python/transpiler/test_vf2_post_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from qiskit.circuit import ControlFlowOp
from qiskit.circuit.library import CXGate, XGate
from qiskit.transpiler import CouplingMap, Layout, TranspilerError
from qiskit.transpiler.passes.layout import vf2_utils
from qiskit.transpiler.passes.layout.vf2_post_layout import VF2PostLayout, VF2PostLayoutStopReason
from qiskit.converters import circuit_to_dag
from qiskit.test import QiskitTestCase
Expand Down Expand Up @@ -385,24 +384,6 @@ def test_all_1q_score(self):
score = vf2_pass._score_layout(layout, bit_map, reverse_bit_map, im_graph)
self.assertAlmostEqual(0.002925, score, places=5)

def test_all_1q_avg_score(self):
"""Test average scoring for all 1q input."""
bit_map = {Qubit(): 0, Qubit(): 1}
reverse_bit_map = {v: k for k, v in bit_map.items()}
im_graph = rustworkx.PyDiGraph()
im_graph.add_node({"sx": 1})
im_graph.add_node({"sx": 1})
backend = FakeYorktownV2()
vf2_pass = VF2PostLayout(target=backend.target)
vf2_pass.avg_error_map = vf2_utils.build_average_error_map(
vf2_pass.target, vf2_pass.properties, vf2_pass.coupling_map
)
layout = {0: 0, 1: 1}
score = vf2_utils.score_layout(
vf2_pass.avg_error_map, layout, bit_map, reverse_bit_map, im_graph
)
self.assertAlmostEqual(0.02054, score, places=5)


class TestVF2PostLayoutUndirected(QiskitTestCase):
"""Tests the VF2Layout pass"""
Expand Down