Skip to content

Commit

Permalink
Nailing down automatic circuit cutter's API (#2168)
Browse files Browse the repository at this point in the history
Adds the CutStrategy class for deriving automatic graph cutter's hyperparameters.
  • Loading branch information
zeyueN authored Feb 25, 2022
1 parent 60bdfd1 commit 1c0a57c
Show file tree
Hide file tree
Showing 4 changed files with 517 additions and 2 deletions.
7 changes: 6 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@
The postprocessing function for the `cut_circuit` transform has been added.
[(#2192)](https://github.com/PennyLaneAI/pennylane/pull/2192)

A class `CutStrategy` which acts as an interface and coordinates device/user
constraints with circuit execution requirements to come up with the best sets
of graph partitioning parameters.
[(#2168)](https://github.com/PennyLaneAI/pennylane/pull/2168)

<h3>Improvements</h3>

* The `gradients` module has been streamlined and special-purpose functions
Expand Down Expand Up @@ -288,6 +293,6 @@ The Operator class has undergone a major refactor with the following changes:
This release contains contributions from (in alphabetical order):

Thomas Bromley, Anthony Hayes, Josh Izaac, Christina Lee,
Maria Fernanda Morris, Maria Schuld, Jay Soni, Antal Száva,
Maria Fernanda Morris, Zeyue Niu, Maria Schuld, Jay Soni, Antal Száva,
David Wierichs

2 changes: 2 additions & 0 deletions pennylane/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
~transforms.expand_fragment_tapes
~transforms.contract_tensors
~transforms.qcut_processing_fn
~transforms.CutStrategy
Transforms that act on tapes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -179,4 +180,5 @@
expand_fragment_tapes,
contract_tensors,
qcut_processing_fn,
CutStrategy,
)
325 changes: 324 additions & 1 deletion pennylane/transforms/qcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@

import copy
import string
import warnings
import uuid
from typing import Sequence, Tuple, List, Dict, Any, Union, ClassVar
from itertools import product
from typing import List, Sequence, Tuple
from dataclasses import dataclass, InitVar

import pennylane as qml
from networkx import MultiDiGraph, weakly_connected_components
Expand Down Expand Up @@ -853,3 +855,324 @@ def qcut_processing_fn(
tensors, communication_graph, prepare_nodes, measure_nodes, use_opt_einsum
)
return result


@dataclass()
class CutStrategy:
"""
A circuit-cutting distribution policy for executing (large) circuits on available (comparably
smaller) devices.
Args:
devices (Union[qml.Device, Sequence[qml.Device]]): Single, or Sequence of, device(s).
Optional only when ``max_free_wires`` is provided.
max_free_wires (int): Number of wires for the largest available device. Optional only when
``devices`` is provided where it defaults to the maximum number of wires among
``devices``.
min_free_wires (int): Number of wires for the smallest available device, or, equivalently,
the smallest max fragment-wire-size that the partitioning is allowed to explore.
When provided, this parameter will be used to derive an upper-bound to the range of
explored number of fragments. Optional, defaults to ``max_free_wires``.
num_fragments_probed (Union[int, Sequence[int]]): Single, or 2-Sequence of, number(s)
specifying the potential (range of) number of fragments for the partitioner to attempt.
Optional, defaults to probing all valid strategies derivable from the circuit and
devices.
max_free_gates (int): Maximum allowed circuit depth for the deepest available device.
Optional, defaults to unlimited depth.
min_free_gates (int): Maximum allowed circuit depth for the shallowest available device.
Optional, defaults to ``max_free_gates``.
imbalance_tolerance (float): The global maximum allowed imbalance for all partition trials.
Optional, defaults to unlimited imbalance. Used only if there's a known hard balancing
constraint on the partitioning problem.
**Example**
The following cut strategy specifies that a circuit should be cut into between
``2`` to ``5`` fragments, with each fragment having at most ``6`` wires and
at least ``4`` wires:
>>> cut_strategy = qml.transforms.CutStrategy(
... max_free_wires=6,
... min_free_wires=4,
... num_fragments_probed=(2, 5),
... )
"""

# pylint: disable=too-many-arguments, too-many-instance-attributes

#: Initialization argument only, used to derive ``max_free_wires`` and ``min_free_wires``.
devices: InitVar[Union[qml.Device, Sequence[qml.Device]]] = None

#: Number of wires for the largest available device.
max_free_wires: int = None
#: Number of wires for the smallest available device.
min_free_wires: int = None
#: The potential (range of) number of fragments for the partitioner to attempt.
num_fragments_probed: Union[int, Sequence[int]] = None
#: Maximum allowed circuit depth for the deepest available device.
max_free_gates: int = None
#: Maximum allowed circuit depth for the shallowest available device.
min_free_gates: int = None
#: The global maximum allowed imbalance for all partition trials.
imbalance_tolerance: float = None

#: Class attribute, threshold for warning about too many fragments.
HIGH_NUM_FRAGMENTS: ClassVar[int] = 20
#: Class attribute, threshold for warning about too many partition attempts.
HIGH_PARTITION_ATTEMPTS: ClassVar[int] = 20

def __post_init__(
self,
devices,
):
"""Deriving cutting constraints from given devices and parameters."""

self.max_free_wires = self.max_free_wires or self.min_free_wires
if isinstance(self.num_fragments_probed, int):
self.num_fragments_probed = [self.num_fragments_probed]
if isinstance(self.num_fragments_probed, (list, tuple)):
self.num_fragments_probed = sorted(self.num_fragments_probed)
self.k_lower = self.num_fragments_probed[0]
self.k_upper = self.num_fragments_probed[-1]
if self.k_lower <= 0:
raise ValueError("`num_fragments_probed` must be positive int(s)")
else:
self.k_lower, self.k_upper = None, None

if devices is None and self.max_free_wires is None:
raise ValueError("One of arguments `devices` and max_free_wires` must be provided.")

if isinstance(devices, qml.Device):
devices = (devices,)

if devices is not None:
if not isinstance(devices, Sequence) or any(
(not isinstance(d, qml.Device) for d in devices)
):
raise ValueError(
"Argument `devices` must be a list or tuple containing elements of type "
"`qml.Device`"
)

device_wire_sizes = [len(d.wires) for d in devices]

self.max_free_wires = self.max_free_wires or max(device_wire_sizes)
self.min_free_wires = self.min_free_wires or min(device_wire_sizes)

if (self.imbalance_tolerance is not None) and not (
isinstance(self.imbalance_tolerance, (float, int)) and self.imbalance_tolerance >= 0
):
raise ValueError(
"The overall `imbalance_tolerance` is expected to be a non-negative number, "
f"got {type(self.imbalance_tolerance)} with value {self.imbalance_tolerance}."
)

def get_cut_kwargs(
self,
tape_dag: MultiDiGraph,
max_wires_by_fragment: Sequence[int] = None,
max_gates_by_fragment: Sequence[int] = None,
) -> List[Dict[str, Any]]:
"""Derive the complete set of arguments, based on a given circuit, for passing to a graph
partitioner.
Args:
tape_dag (MultiDiGraph): Graph representing a tape, typically the output of
:func:`tape_to_graph`.
max_wires_by_fragment (Sequence[int]): User-predetermined list of wire limits by
fragment. If supplied, the number of fragments will be derived from it and
exploration of other choices will not be made.
max_gates_by_fragment (Sequence[int]): User-predetermined list of gate limits by
fragment. If supplied, the number of fragments will be derived from it and
exploration of other choices will not be made.
Returns:
List[Dict[str, Any]]: A list of minimal kwargs being passed to a graph
partitioner method.
**Example**
Deriving kwargs for a given circuit and feeding them to a custom partitioner, along with
extra parameters specified using ``extra_kwargs``:
>>> cut_strategy = qcut.CutStrategy(devices=dev)
>>> cut_kwargs = cut_strategy.get_cut_kwargs(tape_dag)
>>> cut_trials = [
... my_partition_fn(tape_dag, **kwargs, **extra_kwargs) for kwargs in cut_kwargs
... ]
"""
tape_wires = set(w for _, _, w in tape_dag.edges.data("wire"))
num_tape_wires = len(tape_wires)
num_tape_gates = tape_dag.order()
self._validate_input(max_wires_by_fragment, max_gates_by_fragment)

probed_cuts = self._infer_probed_cuts(
num_tape_wires=num_tape_wires,
num_tape_gates=num_tape_gates,
max_wires_by_fragment=max_wires_by_fragment,
max_gates_by_fragment=max_gates_by_fragment,
)

return probed_cuts

@staticmethod
def _infer_imbalance(
k, num_wires, num_gates, free_wires, free_gates, imbalance_tolerance=None
) -> float:
"""Helper function for determining best imbalance limit."""
avg_fragment_wires = (num_wires - 1) // k + 1
avg_fragment_gates = (num_gates - 1) // k + 1
if free_wires < avg_fragment_wires:
raise ValueError(
"`free_wires` should be no less than the average number of wires per fragment. "
f"Got {free_wires} >= {avg_fragment_wires} ."
)
if free_gates < avg_fragment_gates:
raise ValueError(
"`free_gates` should be no less than the average number of gates per fragment. "
f"Got {free_gates} >= {avg_fragment_gates} ."
)

wire_imbalance = free_wires / avg_fragment_wires - 1
gate_imbalance = free_gates / avg_fragment_gates - 1
imbalance = min(gate_imbalance, wire_imbalance)
if imbalance_tolerance is not None:
imbalance = min(imbalance, imbalance_tolerance)

return imbalance

@staticmethod
def _validate_input(
max_wires_by_fragment,
max_gates_by_fragment,
):
"""Helper parameter checker."""
if max_wires_by_fragment is not None:
if not isinstance(max_wires_by_fragment, (list, tuple)):
raise ValueError(
"`max_wires_by_fragment` is expected to be a list or tuple, but got "
f"{type(max_gates_by_fragment)}."
)
if any(not (isinstance(i, int) and i > 0) for i in max_wires_by_fragment):
raise ValueError(
"`max_wires_by_fragment` is expected to contain positive integers only."
)
if max_gates_by_fragment is not None:
if not isinstance(max_gates_by_fragment, (list, tuple)):
raise ValueError(
"`max_gates_by_fragment` is expected to be a list or tuple, but got "
f"{type(max_gates_by_fragment)}."
)
if any(not (isinstance(i, int) and i > 0) for i in max_gates_by_fragment):
raise ValueError(
"`max_gates_by_fragment` is expected to contain positive integers only."
)
if max_wires_by_fragment is not None and max_gates_by_fragment is not None:
if len(max_wires_by_fragment) != len(max_gates_by_fragment):
raise ValueError(
"The lengths of `max_wires_by_fragment` and `max_gates_by_fragment` should be "
f"equal, but got {len(max_wires_by_fragment)} and {len(max_gates_by_fragment)}."
)

def _infer_probed_cuts(
self,
num_tape_wires,
num_tape_gates,
max_wires_by_fragment=None,
max_gates_by_fragment=None,
) -> List[Dict[str, Any]]:
"""
Helper function for deriving the minimal set of best default partitioning constraints
for the graph partitioner.
Args:
num_tape_wires (int): Number of wires in the circuit tape to be partitioned.
num_tape_gates (int): Number of gates in the circuit tape to be partitioned.
max_wires_by_fragment (Sequence[int]): User-predetermined list of wire limits by
fragment. If supplied, the number of fragments will be derived from it and
exploration of other choices will not be made.
max_gates_by_fragment (Sequence[int]): User-predetermined list of gate limits by
fragment. If supplied, the number of fragments will be derived from it and
exploration of other choices will not be made.
Returns:
List[Dict[str, Any]]: A list of minimal set of kwargs being passed to a graph
partitioner method.
"""

# Assumes unlimited width/depth if not supplied.
max_free_wires = self.max_free_wires or num_tape_wires
max_free_gates = self.max_free_gates or num_tape_gates

# Assumes same number of wires/gates across all devices if min_free_* not provided.
min_free_wires = self.min_free_wires or max_free_wires
min_free_gates = self.min_free_gates or max_free_gates

# The lower bound of k corresponds to executing each fragment on the largest available device.
k_lb = 1 + max(
(num_tape_wires - 1) // max_free_wires, # wire limited
(num_tape_gates - 1) // max_free_gates, # gate limited
)
# The upper bound of k corresponds to executing each fragment on the smallest available device.
k_ub = 1 + max(
(num_tape_wires - 1) // min_free_wires, # wire limited
(num_tape_gates - 1) // min_free_gates, # gate limited
)

# The global imbalance tolerance, if not given, defaults to a very loose upper bound:
imbalance_tolerance = k_ub if self.imbalance_tolerance is None else self.imbalance_tolerance

probed_cuts = []

if max_gates_by_fragment is None and max_wires_by_fragment is None:

# k_lower, when supplied by a user, can be higher than k_lb if the the desired k is known:
k_lower = self.k_lower if self.k_lower is not None else k_lb
# k_upper, when supplied by a user, can be higher than k_ub to encourage exploration:
k_upper = self.k_upper if self.k_upper is not None else k_ub

if k_lower < k_lb:
warnings.warn(
f"The provided `k_lower={k_lower}` is less than the lowest allowed value, "
f"will override and set `k_lower={k_lb}`."
)
k_lower = k_lb

if k_lower > self.HIGH_NUM_FRAGMENTS:
warnings.warn(
f"The attempted number of fragments seems high with lower bound at {k_lower}."
)

# Prepare the list of ks to explore:
ks = list(range(k_lower, k_upper + 1))

if len(ks) > self.HIGH_PARTITION_ATTEMPTS:
warnings.warn(f"The numer of partition attempts seems high ({len(ks)}).")
else:
# When the by-fragment wire and/or gate limits are supplied, derive k and imbalance and
# return a single partition config.
ks = [len(max_wires_by_fragment or max_gates_by_fragment)]

for k in ks:
imbalance = self._infer_imbalance(
k,
num_tape_wires,
num_tape_gates,
max_free_wires if max_wires_by_fragment is None else max(max_wires_by_fragment),
max_free_gates if max_gates_by_fragment is None else max(max_gates_by_fragment),
imbalance_tolerance,
)
cut_kwargs = {
"num_fragments": k,
"imbalance": imbalance,
}
if max_wires_by_fragment is not None:
cut_kwargs["max_wires_by_fragment"] = max_wires_by_fragment
if max_gates_by_fragment is not None:
cut_kwargs["max_gates_by_fragment"] = max_gates_by_fragment

probed_cuts.append(cut_kwargs)

return probed_cuts
Loading

0 comments on commit 1c0a57c

Please sign in to comment.