Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use rustworkx version of token_swapper #10001

Merged
merged 14 commits into from
Jul 13, 2023
164 changes: 11 additions & 153 deletions qiskit/transpiler/passes/routing/algorithms/token_swapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@
"""Permutation algorithms for general graphs."""

from __future__ import annotations
import copy
import logging
from collections.abc import Mapping, MutableSet, MutableMapping, Iterator, Iterable
from collections.abc import Mapping

import numpy as np
import rustworkx as rx
Expand Down Expand Up @@ -76,10 +75,13 @@ def permutation_circuit(self, permutation: Permutation, trials: int = 4) -> Perm
The circuit to implement the permutation
"""
sequential_swaps = self.map(permutation, trials=trials)

parallel_swaps = [[swap] for swap in sequential_swaps]
return permutation_circuit(parallel_swaps)

def map(self, mapping: Mapping[int, int], trials: int = 4) -> list[Swap[int]]:
def map(
self, mapping: Mapping[int, int], trials: int = 4, parallel_threshold: int = 50
) -> list[Swap[int]]:
"""Perform an approximately optimal Token Swapping algorithm to implement the permutation.

Supports partial mappings (i.e. not-permutations) for graphs with missing tokens.
Expand All @@ -91,157 +93,13 @@ def map(self, mapping: Mapping[int, int], trials: int = 4) -> list[Swap[int]]:
Args:
mapping: The partial mapping to implement in swaps.
trials: The number of trials to try to perform the mapping. Minimize over the trials.
parallel_threshold: The number of nodes in the graph beyond which the algorithm
will use parallel processing

Returns:
The swaps to implement the mapping
"""
tokens = dict(mapping)
digraph = rx.PyDiGraph()
sub_digraph = rx.PyDiGraph() # Excludes self-loops in digraph.
todo_nodes = {node for node, destination in tokens.items() if node != destination}
for node in self.graph.node_indexes():
self._add_token_edges(node, tokens, digraph, sub_digraph)

trial_results = iter(
list(
self._trial_map(
copy.copy(digraph), copy.copy(sub_digraph), todo_nodes.copy(), tokens.copy()
)
)
for _ in range(trials)
)

# Once we find a zero solution we stop.
def take_until_zero(results: Iterable[list[Swap[int]]]) -> Iterable[list[Swap[int]]]:
"""Take results until one is emitted of length zero (and also emit that)."""
for result in results:
yield result
if not result:
break

trial_results = take_until_zero(trial_results)
return min(trial_results, key=len)

def _trial_map(
self,
digraph: rx.PyDiGraph,
sub_digraph: rx.PyDiGraph,
todo_nodes: MutableSet[int],
tokens: MutableMapping[int, int],
) -> Iterator[Swap[int]]:
"""Try to map the tokens to their destinations and minimize the number of swaps."""

def swap(node0: int, node1: int) -> None:
"""Swap two nodes, maintaining data structures.

Args:
node0: The first node
node1: The second node
"""
self._swap(node0, node1, tokens, digraph, sub_digraph, todo_nodes)

# Can't just iterate over todo_nodes, since it may change during iteration.
steps = 0
while todo_nodes and steps <= 4 * len(self.graph) ** 2:
todo_node_id = self.seed.integers(0, len(todo_nodes))
todo_node = tuple(todo_nodes)[todo_node_id]

# Try to find a happy swap chain first by searching for a cycle,
# excluding self-loops.
# Note that if there are only unhappy swaps involving this todo_node,
# then an unhappy swap must be performed at some point.
# So it is not useful to globally search for all happy swap chains first.
cycle = rx.digraph_find_cycle(sub_digraph, source=todo_node)
if len(cycle) > 0:
assert len(cycle) > 1, "The cycle was not happy."
# We iterate over the cycle in reversed order, starting at the last edge.
# The first edge is excluded.
for edge in list(cycle)[-1:0:-1]:
yield edge
swap(edge[0], edge[1])
steps += len(cycle) - 1
else:
# Try to find a node without a token to swap with.
try:
edge = next(
edge
for edge in rx.digraph_dfs_edges(sub_digraph, todo_node)
if edge[1] not in tokens
)
# Swap predecessor and successor, because successor does not have a token
yield edge
swap(edge[0], edge[1])
steps += 1
except StopIteration:
# Unhappy swap case
cycle = rx.digraph_find_cycle(digraph, source=todo_node)
assert len(cycle) == 1, "The cycle was not unhappy."
unhappy_node = cycle[0][0]
# Find a node that wants to swap with this node.
try:
predecessor = next(
predecessor
for predecessor in digraph.predecessor_indices(unhappy_node)
if predecessor != unhappy_node
)
except StopIteration:
logger.error(
"Unexpected StopIteration raised when getting predecessors"
"in unhappy swap case."
)
return
yield unhappy_node, predecessor
swap(unhappy_node, predecessor)
steps += 1
if todo_nodes:
raise RuntimeError("Too many iterations while approximating the Token Swaps.")

def _add_token_edges(
self, node: int, tokens: Mapping[int, int], digraph: rx.PyDiGraph, sub_digraph: rx.PyDiGraph
) -> None:
"""Add diedges to the graph wherever a token can be moved closer to its destination."""
if node not in tokens:
return

if tokens[node] == node:
digraph.extend_from_edge_list([(node, node)])
return

for neighbor in self.graph.neighbors(node):
if self.distance(neighbor, tokens[node]) < self.distance(node, tokens[node]):
digraph.extend_from_edge_list([(node, neighbor)])
sub_digraph.extend_from_edge_list([(node, neighbor)])

def _swap(
self,
node1: int,
node2: int,
tokens: MutableMapping[int, int],
digraph: rx.PyDiGraph,
sub_digraph: rx.PyDiGraph,
todo_nodes: MutableSet[int],
) -> None:
"""Swap two nodes, maintaining the data structures."""
assert self.graph.has_edge(
node1, node2
), "The swap is being performed on a non-existent edge."
# Swap the tokens on the nodes, taking into account no-token nodes.
token1 = tokens.pop(node1, None)
token2 = tokens.pop(node2, None)
if token2 is not None:
tokens[node1] = token2
if token1 is not None:
tokens[node2] = token1
# Recompute the edges incident to node 1 and 2
for node in [node1, node2]:
digraph.remove_edges_from(
[(node, successor) for successor in digraph.successor_indices(node)]
)
sub_digraph.remove_edges_from(
[(node, successor) for successor in sub_digraph.successor_indices(node)]
)
self._add_token_edges(node, tokens, digraph, sub_digraph)
if node in tokens and tokens[node] != node:
todo_nodes.add(node)
elif node in todo_nodes:
todo_nodes.remove(node)
# Since integer seed is used in rustworkx, take random integer from np.random.randint
# and use that for the seed.
seed = self.seed.integers(1, 10000)
return rx.graph_token_swapper(self.graph, mapping, trials, seed, parallel_threshold)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
upgrade:
- |
The :meth:`~.ApproximateTokenSwapper.map` has been modified to use the new ``rustworkx`` version
of :func:`~graph_token_swapper` for performance reasons. Qiskit Terra 0.25 now requires versison
0.13.0 of ``rustworkx``.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
rustworkx>=0.12.0
rustworkx>=0.13.0
numpy>=1.17
ply>=3.10
psutil>=5
Expand Down
5 changes: 5 additions & 0 deletions test/python/transpiler/test_layout_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@ def test_four_qubit_with_target(self):

self.assertEqual(circuit_to_dag(expected), output_dag)

@unittest.skip("rustworkx token_swapper produces correct, but sometimes random output")
def test_full_connected_coupling_map(self):
"""Test if the permutation {0->3,1->0,2->1,3->2} in a fully connected map."""

# TODO: Remove skip when https://github.com/Qiskit/rustworkx/pull/897 is
# merged and released. Should be rustworkx 0.13.1.

v = QuantumRegister(4, "v") # virtual qubits
from_layout = Layout({v[0]: 0, v[1]: 1, v[2]: 2, v[3]: 3})
to_layout = Layout({v[0]: 3, v[1]: 0, v[2]: 1, v[3]: 2})
Expand Down