From 88595a54daabe6b8303ca3107d26bff996ef7860 Mon Sep 17 00:00:00 2001 From: Jake Lishman Date: Thu, 29 Jun 2023 17:09:44 +0100 Subject: [PATCH] Add `Expr` support to `ControlFlowOp` representation This commit allows `IfElseOp` and `WhileLoopOp` to have a `condition` that is an `Expr` typed `Bool()`, and `SwitchCaseOp` to have a `target` that is an `Expr`. It does not fully add support to resolving these new allowed values into the control-flow builder interface (that will come later). --- qiskit/circuit/classical/expr/__init__.py | 8 +- qiskit/circuit/classical/expr/constructors.py | 4 +- qiskit/circuit/classical/expr/expr.py | 4 + qiskit/circuit/classical/expr/visitors.py | 45 +++++ qiskit/circuit/controlflow/__init__.py | 1 + qiskit/circuit/controlflow/_builder_utils.py | 82 +++++++- qiskit/circuit/controlflow/builder.py | 4 +- qiskit/circuit/controlflow/condition.py | 76 ------- qiskit/circuit/controlflow/if_else.py | 23 ++- qiskit/circuit/controlflow/switch_case.py | 18 +- qiskit/circuit/controlflow/while_loop.py | 22 +-- qiskit/circuit/quantumcircuit.py | 36 +++- qiskit/dagcircuit/collect_blocks.py | 4 +- qiskit/dagcircuit/dagcircuit.py | 36 ++-- qiskit/dagcircuit/dagdependency.py | 6 +- test/python/circuit/test_control_flow.py | 187 ++++++++++++------ 16 files changed, 359 insertions(+), 197 deletions(-) delete mode 100644 qiskit/circuit/controlflow/condition.py diff --git a/qiskit/circuit/classical/expr/__init__.py b/qiskit/circuit/classical/expr/__init__.py index 6aa20ca5fd94..43ac310665dd 100644 --- a/qiskit/circuit/classical/expr/__init__.py +++ b/qiskit/circuit/classical/expr/__init__.py @@ -137,6 +137,11 @@ that they wish to handle. Any non-overridden methods will call :meth:`~ExprVisitor.visit_generic`, which unless overridden will raise a ``RuntimeError`` to ensure that you are aware if new nodes have been added to the expression tree that you are not yet handling. + +For the convenience of simple visitors that only need to inspect the variables in an expression and +not the general structure, the iterator method :func:`iter_vars` is provided. + +.. autofunction:: iter_vars """ __all__ = [ @@ -147,6 +152,7 @@ "Unary", "Binary", "ExprVisitor", + "iter_vars", "lift", "cast", "bit_not", @@ -166,7 +172,7 @@ ] from .expr import Expr, Var, Value, Cast, Unary, Binary -from .visitors import ExprVisitor +from .visitors import ExprVisitor, iter_vars from .constructors import ( lift, cast, diff --git a/qiskit/circuit/classical/expr/constructors.py b/qiskit/circuit/classical/expr/constructors.py index b0c476b4e627..ac976301f16b 100644 --- a/qiskit/circuit/classical/expr/constructors.py +++ b/qiskit/circuit/classical/expr/constructors.py @@ -116,7 +116,7 @@ def lift_legacy_condition( lifted = expr.lift_legacy_condition(instr.condition) """ - from qiskit.circuit import Clbit + from qiskit.circuit import Clbit # pylint: disable=cyclic-import target, value = condition if isinstance(target, Clbit): @@ -159,7 +159,7 @@ def lift(value: typing.Any, /, type: types.Type | None = None) -> Expr: if type is not None: raise ValueError("use 'cast' to cast existing expressions, not 'lift'") return value - from qiskit.circuit import Clbit, ClassicalRegister + from qiskit.circuit import Clbit, ClassicalRegister # pylint: disable=cyclic-import inferred: types.Type if value is True or value is False or isinstance(value, Clbit): diff --git a/qiskit/circuit/classical/expr/expr.py b/qiskit/circuit/classical/expr/expr.py index 5978b908d11c..b9e9aad4a2b7 100644 --- a/qiskit/circuit/classical/expr/expr.py +++ b/qiskit/circuit/classical/expr/expr.py @@ -37,9 +37,13 @@ if typing.TYPE_CHECKING: import qiskit + _T_co = typing.TypeVar("_T_co", covariant=True) +# If adding nodes, remember to update `visitors.ExprVisitor` as well. + + class Expr(abc.ABC): """Root base class of all nodes in the expression tree. The base case should never be instantiated directly. diff --git a/qiskit/circuit/classical/expr/visitors.py b/qiskit/circuit/classical/expr/visitors.py index a61fec932ddb..82fe0b0a497d 100644 --- a/qiskit/circuit/classical/expr/visitors.py +++ b/qiskit/circuit/classical/expr/visitors.py @@ -16,6 +16,7 @@ __all__ = [ "ExprVisitor", + "iter_vars", ] import typing @@ -30,6 +31,7 @@ class ExprVisitor(typing.Generic[_T_co]): the ``visit_*`` methods that they are able to handle, and should be organised such that non-existent methods will never be called.""" + # The method names are self-explanatory and docstrings would just be noise. # pylint: disable=missing-function-docstring __slots__ = () @@ -51,3 +53,46 @@ def visit_binary(self, node: expr.Binary, /) -> _T_co: # pragma: no cover def visit_cast(self, node: expr.Cast, /) -> _T_co: # pragma: no cover return self.visit_generic(node) + + +class _VarWalkerImpl(ExprVisitor[typing.Iterable[expr.Var]]): + __slots__ = () + + def visit_var(self, node, /): + yield node + + def visit_value(self, node, /): + yield from () + + def visit_unary(self, node, /): + yield from node.operand.accept(self) + + def visit_binary(self, node, /): + yield from node.left.accept(self) + yield from node.right.accept(self) + + def visit_cast(self, node, /): + yield from node.operand.accept(self) + + +_VAR_WALKER = _VarWalkerImpl() + + +def iter_vars(node: expr.Expr) -> typing.Iterator[expr.Var]: + """Get an iterator over the :class:`~.expr.Var` nodes referenced at any level in the given + :class:`~.expr.Expr`. + + Examples: + Print out the name of each :class:`.ClassicalRegister` encountered:: + + from qiskit.circuit import ClassicalRegister + from qiskit.circuit.classical import expr + + cr1 = ClassicalRegister(3, "a") + cr2 = ClassicalRegister(3, "b") + + for node in expr.iter_vars(expr.bit_and(expr.bit_not(cr1), cr2)): + if isinstance(node.var, ClassicalRegister): + print(node.var.name) + """ + yield from node.accept(_VAR_WALKER) diff --git a/qiskit/circuit/controlflow/__init__.py b/qiskit/circuit/controlflow/__init__.py index 60df9c2a370a..abec942880a8 100644 --- a/qiskit/circuit/controlflow/__init__.py +++ b/qiskit/circuit/controlflow/__init__.py @@ -13,6 +13,7 @@ """Instruction sub-classes for dynamic circuits.""" +from ._builder_utils import condition_resources, node_resources, LegacyResources from .control_flow import ControlFlowOp from .continue_loop import ContinueLoopOp from .break_loop import BreakLoopOp diff --git a/qiskit/circuit/controlflow/_builder_utils.py b/qiskit/circuit/controlflow/_builder_utils.py index 49950a6655b2..b78079df7ae2 100644 --- a/qiskit/circuit/controlflow/_builder_utils.py +++ b/qiskit/circuit/controlflow/_builder_utils.py @@ -12,14 +12,92 @@ """Private utility functions that are used by the builder interfaces.""" -from typing import Iterable, Tuple, Set +from __future__ import annotations +import dataclasses +from typing import Iterable, Tuple, Set, Union, TypeVar + +from qiskit.circuit.classical import expr, types from qiskit.circuit.exceptions import CircuitError from qiskit.circuit.quantumcircuit import QuantumCircuit from qiskit.circuit.register import Register -from qiskit.circuit.classicalregister import ClassicalRegister +from qiskit.circuit.classicalregister import ClassicalRegister, Clbit from qiskit.circuit.quantumregister import QuantumRegister +_ConditionT = TypeVar( + "_ConditionT", bound=Union[Tuple[ClassicalRegister, int], Tuple[Clbit, int], expr.Expr] +) + + +def validate_condition(condition: _ConditionT) -> _ConditionT: + """Validate that a condition is in a valid format and return it, but raise if it is invalid. + + Args: + condition: the condition to be tested for validity. Must be either the legacy 2-tuple + format, or a :class:`~.expr.Expr` that has `Bool` type. + + Raises: + CircuitError: if the condition is not in a valid format. + + Returns: + The same condition as passed, if it was valid. + """ + if isinstance(condition, expr.Expr): + if condition.type.kind is not types.Bool: + raise CircuitError( + "Classical conditions must be expressions with the type 'Bool()'," + f" not '{condition.type}'." + ) + return condition + try: + bits, value = condition + if isinstance(bits, (ClassicalRegister, Clbit)) and isinstance(value, int): + return (bits, value) + except (TypeError, ValueError): + pass + raise CircuitError( + "A classical condition should be a 2-tuple of `(ClassicalRegister | Clbit, int)`," + f" but received '{condition!r}'." + ) + + +@dataclasses.dataclass +class LegacyResources: + """A pair of the :class:`.Clbit` and :class:`.ClassicalRegister` resources used by some other + object (such as a legacy condition or :class:`.expr.Expr` node).""" + + clbits: tuple[Clbit, ...] + cregs: tuple[ClassicalRegister, ...] + + +def node_resources(node: expr.Expr) -> LegacyResources: + """Get the legacy classical resources (:class:`.Clbit` and :class:`.ClassicalRegister`) + referenced by an :class:`~.expr.Expr`.""" + # It's generally convenient for us to ensure that the resources are returned in some + # deterministic order. This uses the ordering of 'dict' objects to fake out an ordered set. + clbits = {} + cregs = {} + for var in expr.iter_vars(node): + if isinstance(var.var, Clbit): + clbits[var.var] = None + elif isinstance(var.var, ClassicalRegister): + clbits.update((bit, None) for bit in var.var) + cregs[var.var] = None + return LegacyResources(tuple(clbits), tuple(cregs)) + + +def condition_resources( + condition: tuple[ClassicalRegister, int] | tuple[Clbit, int] | expr.Expr +) -> LegacyResources: + """Get the legacy classical resources (:class:`.Clbit` and :class:`.ClassicalRegister`) + referenced by a legacy condition or an :class:`~.expr.Expr`.""" + if isinstance(condition, expr.Expr): + return node_resources(condition) + target, _ = condition + if isinstance(target, ClassicalRegister): + return LegacyResources(tuple(target), (target,)) + return LegacyResources((target,), ()) + def partition_registers( registers: Iterable[Register], diff --git a/qiskit/circuit/controlflow/builder.py b/qiskit/circuit/controlflow/builder.py index 0d929f70d44b..c837b3471293 100644 --- a/qiskit/circuit/controlflow/builder.py +++ b/qiskit/circuit/controlflow/builder.py @@ -30,7 +30,7 @@ from qiskit.circuit.quantumregister import Qubit, QuantumRegister from qiskit.circuit.register import Register -from .condition import condition_registers +from ._builder_utils import condition_resources if typing.TYPE_CHECKING: import qiskit # pylint: disable=cyclic-import @@ -447,7 +447,7 @@ def build( self.add_register(register) out.add_register(register) if getattr(instruction.operation, "condition", None) is not None: - for register in condition_registers(instruction.operation.condition): + for register in condition_resources(instruction.operation.condition).cregs: if register not in self.registers: self.add_register(register) out.add_register(register) diff --git a/qiskit/circuit/controlflow/condition.py b/qiskit/circuit/controlflow/condition.py deleted file mode 100644 index 626b74faded1..000000000000 --- a/qiskit/circuit/controlflow/condition.py +++ /dev/null @@ -1,76 +0,0 @@ -# This code is part of Qiskit. -# -# (C) Copyright IBM 2021. -# -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. - -"""Functions for dealing with classical conditions.""" - -from typing import Tuple, Union - -from qiskit.circuit.classicalregister import ClassicalRegister, Clbit -from qiskit.circuit.exceptions import CircuitError - - -def validate_condition( - condition: Tuple[Union[ClassicalRegister, Clbit], int] -) -> Tuple[Union[ClassicalRegister, Clbit], int]: - """Validate that a condition is in a valid format and return it, but raise if it is invalid. - - Args: - condition: the condition to be tested for validity. - - Raises: - CircuitError: if the condition is not in a valid format. - - Returns: - The same condition as passed, if it was valid. - """ - try: - bits, value = condition - if isinstance(bits, (ClassicalRegister, Clbit)) and isinstance(value, int): - return (bits, value) - except (TypeError, ValueError): - pass - raise CircuitError( - "A classical condition should be a 2-tuple of `(ClassicalRegister | Clbit, int)`," - f" but received '{condition!r}'." - ) - - -def condition_bits(condition: Tuple[Union[ClassicalRegister, Clbit], int]) -> Tuple[Clbit, ...]: - """Return the classical resources used by ``condition`` as a tuple of :obj:`.Clbit`. - - This is useful when the exact set of bits is required, rather than the logical grouping of - :obj:`.ClassicalRegister`, such as when determining circuit blocking. - - Args: - condition: the valid condition to extract the bits from. - - Returns: - a tuple of all classical bits used in the condition. - """ - return (condition[0],) if isinstance(condition[0], Clbit) else tuple(condition[0]) - - -def condition_registers( - condition: Tuple[Union[ClassicalRegister, Clbit], int] -) -> Tuple[ClassicalRegister, ...]: - """Return any classical registers used by ``condition`` as a tuple of :obj:`.ClassicalRegister`. - - This is useful as a quick method for extracting the registers from a condition, if any exist. - The output might be empty if the condition is on a single bit. - - Args: - condition: the valid condition to extract any registers from. - - Returns: - a tuple of all classical registers used in the condition. - """ - return (condition[0],) if isinstance(condition[0], ClassicalRegister) else () diff --git a/qiskit/circuit/controlflow/if_else.py b/qiskit/circuit/controlflow/if_else.py index 23dc3a1d9ddd..f4095f7484e7 100644 --- a/qiskit/circuit/controlflow/if_else.py +++ b/qiskit/circuit/controlflow/if_else.py @@ -10,20 +10,26 @@ # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. -"Circuit operation representing an ``if/else`` statement." +"""Circuit operation representing an ``if/else`` statement.""" +from __future__ import annotations from typing import Optional, Tuple, Union, Iterable import itertools from qiskit.circuit import ClassicalRegister, Clbit, QuantumCircuit +from qiskit.circuit.classical import expr from qiskit.circuit.instructionset import InstructionSet from qiskit.circuit.exceptions import CircuitError from .builder import ControlFlowBuilderBlock, InstructionPlaceholder, InstructionResources -from .condition import validate_condition, condition_bits, condition_registers from .control_flow import ControlFlowOp -from ._builder_utils import partition_registers, unify_circuit_resources +from ._builder_utils import ( + partition_registers, + unify_circuit_resources, + validate_condition, + condition_resources, +) # This is just an indication of what's actually meant to be the public API. @@ -71,10 +77,10 @@ class IfElseOp(ControlFlowOp): def __init__( self, - condition: Tuple[Union[ClassicalRegister, Clbit], int], + condition: tuple[ClassicalRegister, int] | tuple[Clbit, int] | expr.Expr, true_body: QuantumCircuit, - false_body: Optional[QuantumCircuit] = None, - label: Optional[str] = None, + false_body: QuantumCircuit | None = None, + label: str | None = None, ): # Type checking generally left to @params.setter, but required here for # finding num_qubits and num_clbits. @@ -364,9 +370,10 @@ def in_loop(self) -> bool: return self._in_loop def __enter__(self): + resources = condition_resources(self._condition) self._circuit._push_scope( - clbits=condition_bits(self._condition), - registers=condition_registers(self._condition), + clbits=resources.clbits, + registers=resources.cregs, allow_jumps=self._in_loop, ) return ElseContext(self) diff --git a/qiskit/circuit/controlflow/switch_case.py b/qiskit/circuit/controlflow/switch_case.py index 179254970f90..119249dae4e1 100644 --- a/qiskit/circuit/controlflow/switch_case.py +++ b/qiskit/circuit/controlflow/switch_case.py @@ -12,12 +12,15 @@ """Circuit operation representing an ``switch/case`` statement.""" +from __future__ import annotations + __all__ = ("SwitchCaseOp", "CASE_DEFAULT") import contextlib from typing import Union, Iterable, Any, Tuple, Optional, List, Literal from qiskit.circuit import ClassicalRegister, Clbit, QuantumCircuit +from qiskit.circuit.classical import expr, types from qiskit.circuit.exceptions import CircuitError from .builder import InstructionPlaceholder, InstructionResources, ControlFlowBuilderBlock @@ -64,15 +67,24 @@ class SwitchCaseOp(ControlFlowOp): def __init__( self, - target: Union[Clbit, ClassicalRegister], + target: Clbit | ClassicalRegister | expr.Expr, cases: Iterable[Tuple[Any, QuantumCircuit]], *, label: Optional[str] = None, ): - if not isinstance(target, (Clbit, ClassicalRegister)): + if isinstance(target, expr.Expr): + if target.type.kind not in (types.Uint, types.Bool): + raise CircuitError( + "the switch target must be an expression with type 'Uint(n)' or 'Bool()'," + f" not '{target.type}'" + ) + elif not isinstance(target, (Clbit, ClassicalRegister)): raise CircuitError("the switch target must be a classical bit or register") - target_bits = 1 if isinstance(target, Clbit) else len(target) + if isinstance(target, expr.Expr): + target_bits = 1 if target.type.kind is types.Bool else target.type.width + else: + target_bits = 1 if isinstance(target, Clbit) else len(target) target_max = (1 << target_bits) - 1 case_ids = set() diff --git a/qiskit/circuit/controlflow/while_loop.py b/qiskit/circuit/controlflow/while_loop.py index bc5c30973087..db724cf26a68 100644 --- a/qiskit/circuit/controlflow/while_loop.py +++ b/qiskit/circuit/controlflow/while_loop.py @@ -10,13 +10,16 @@ # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. -"Circuit operation representing a ``while`` loop." +"""Circuit operation representing a ``while`` loop.""" -from typing import Optional, Tuple, Union +from __future__ import annotations + +from typing import Union, Tuple, Optional from qiskit.circuit import Clbit, ClassicalRegister, QuantumCircuit +from qiskit.circuit.classical import expr from qiskit.circuit.exceptions import CircuitError -from .condition import validate_condition, condition_bits, condition_registers +from ._builder_utils import validate_condition, condition_resources from .control_flow import ControlFlowOp @@ -53,13 +56,9 @@ class WhileLoopOp(ControlFlowOp): def __init__( self, - condition: Union[ - Tuple[ClassicalRegister, int], - Tuple[Clbit, int], - Tuple[Clbit, bool], - ], + condition: tuple[ClassicalRegister, int] | tuple[Clbit, int] | expr.Expr, body: QuantumCircuit, - label: Optional[str] = None, + label: str | None = None, ): num_qubits = body.num_qubits num_clbits = body.num_clbits @@ -155,9 +154,8 @@ def __init__( self._label = label def __enter__(self): - self._circuit._push_scope( - clbits=condition_bits(self._condition), registers=condition_registers(self._condition) - ) + resources = condition_resources(self._condition) + self._circuit._push_scope(clbits=resources.clbits, registers=resources.cregs) def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None: diff --git a/qiskit/circuit/quantumcircuit.py b/qiskit/circuit/quantumcircuit.py index ac5079cb3e11..c935da6a0809 100644 --- a/qiskit/circuit/quantumcircuit.py +++ b/qiskit/circuit/quantumcircuit.py @@ -49,6 +49,7 @@ from qiskit.qasm.exceptions import QasmError from qiskit.circuit.exceptions import CircuitError from qiskit.utils import optionals as _optionals +from .classical import expr from .parameterexpression import ParameterExpression, ParameterValueType from .quantumregister import QuantumRegister, Qubit, AncillaRegister, AncillaQubit from .classicalregister import ClassicalRegister, Clbit @@ -1228,6 +1229,16 @@ def _resolve_classical_resource(self, specifier): raise CircuitError(f"Classical bit index {specifier} is out-of-range.") from None raise CircuitError(f"Unknown classical resource specifier: '{specifier}'.") + def _validate_expr(self, node: expr.Expr) -> expr.Expr: + for var in expr.iter_vars(node): + if isinstance(var.var, Clbit): + if var.var not in self._clbit_indices: + raise CircuitError(f"Clbit {var.var} is not present in this circuit.") + elif isinstance(var.var, ClassicalRegister): + if var.var not in self.cregs: + raise CircuitError(f"Register {var.var} is not present in this circuit.") + return node + def append( self, instruction: Operation | CircuitInstruction, @@ -4330,7 +4341,7 @@ def _update_parameter_table_on_instruction_removal(self, instruction: CircuitIns @typing.overload def while_loop( self, - condition: tuple[ClassicalRegister | Clbit, int], + condition: tuple[ClassicalRegister | Clbit, int] | expr.Expr, body: None, qubits: None, clbits: None, @@ -4342,7 +4353,7 @@ def while_loop( @typing.overload def while_loop( self, - condition: tuple[ClassicalRegister | Clbit, int], + condition: tuple[ClassicalRegister | Clbit, int] | expr.Expr, body: "QuantumCircuit", qubits: Sequence[QubitSpecifier], clbits: Sequence[ClbitSpecifier], @@ -4397,7 +4408,10 @@ def while_loop(self, condition, body=None, qubits=None, clbits=None, *, label=No # pylint: disable=cyclic-import from qiskit.circuit.controlflow.while_loop import WhileLoopOp, WhileLoopContext - condition = (self._resolve_classical_resource(condition[0]), condition[1]) + if isinstance(condition, expr.Expr): + condition = self._validate_expr(condition) + else: + condition = (self._resolve_classical_resource(condition[0]), condition[1]) if body is None: if qubits is not None or clbits is not None: @@ -4602,7 +4616,10 @@ def if_test( # pylint: disable=cyclic-import from qiskit.circuit.controlflow.if_else import IfElseOp, IfContext - condition = (self._resolve_classical_resource(condition[0]), condition[1]) + if isinstance(condition, expr.Expr): + condition = self._validate_expr(condition) + else: + condition = (self._resolve_classical_resource(condition[0]), condition[1]) if true_body is None: if qubits is not None or clbits is not None: @@ -4668,7 +4685,11 @@ def if_else( # pylint: disable=cyclic-import from qiskit.circuit.controlflow.if_else import IfElseOp - condition = (self._resolve_classical_resource(condition[0]), condition[1]) + if isinstance(condition, expr.Expr): + condition = self._validate_expr(condition) + else: + condition = (self._resolve_classical_resource(condition[0]), condition[1]) + return self.append(IfElseOp(condition, true_body, false_body, label), qubits, clbits) @typing.overload @@ -4749,7 +4770,10 @@ def switch(self, target, cases=None, qubits=None, clbits=None, *, label=None): # pylint: disable=cyclic-import from qiskit.circuit.controlflow.switch_case import SwitchCaseOp, SwitchContext - target = self._resolve_classical_resource(target) + if isinstance(target, expr.Expr): + target = self._validate_expr(target) + else: + target = self._resolve_classical_resource(target) if cases is None: if qubits is not None or clbits is not None: raise CircuitError( diff --git a/qiskit/dagcircuit/collect_blocks.py b/qiskit/dagcircuit/collect_blocks.py index 3c09d5dcb82b..cb447128ab93 100644 --- a/qiskit/dagcircuit/collect_blocks.py +++ b/qiskit/dagcircuit/collect_blocks.py @@ -15,7 +15,7 @@ into smaller sub-blocks, and to consolidate blocks.""" from qiskit.circuit import QuantumCircuit, CircuitInstruction, ClassicalRegister -from qiskit.circuit.controlflow.condition import condition_bits +from qiskit.circuit.controlflow import condition_resources from . import DAGOpNode, DAGCircuit, DAGDependency from .exceptions import DAGCircuitError @@ -272,7 +272,7 @@ def collapse_to_operation(self, blocks, collapse_fn): cur_clbits.update(node.cargs) cond = getattr(node.op, "condition", None) if cond is not None: - cur_clbits.update(condition_bits(cond)) + cur_clbits.update(condition_resources(cond).clbits) if isinstance(cond[0], ClassicalRegister): cur_clregs.append(cond[0]) diff --git a/qiskit/dagcircuit/dagcircuit.py b/qiskit/dagcircuit/dagcircuit.py index a0eb6feec718..35313de08ec5 100644 --- a/qiskit/dagcircuit/dagcircuit.py +++ b/qiskit/dagcircuit/dagcircuit.py @@ -30,8 +30,7 @@ import rustworkx as rx from qiskit.circuit import ControlFlowOp, ForLoopOp, IfElseOp, WhileLoopOp, SwitchCaseOp -from qiskit.circuit.controlflow.condition import condition_bits -from qiskit.circuit.exceptions import CircuitError +from qiskit.circuit.controlflow import condition_resources from qiskit.circuit.quantumregister import QuantumRegister, Qubit from qiskit.circuit.classicalregister import ClassicalRegister, Clbit from qiskit.circuit.gate import Gate @@ -450,12 +449,14 @@ def _check_condition(self, name, condition): Raises: DAGCircuitError: if conditioning on an invalid register """ - if ( - condition is not None - and condition[0] not in self.clbits - and condition[0].name not in self.cregs - ): - raise DAGCircuitError("invalid creg in condition for %s" % name) + if condition is None: + return + resources = condition_resources(condition) + for creg in resources.cregs: + if creg.name not in self.cregs: + raise DAGCircuitError(f"invalid creg in condition for {name}") + if not set(resources.clbits).issubset(self.clbits): + raise DAGCircuitError(f"invalid clbits in condition for {name}") def _check_bits(self, args, amap): """Check the values of a list of (qu)bit arguments. @@ -479,24 +480,13 @@ def _bits_in_condition(cond): """Return a list of bits in the given condition. Args: - cond (tuple or None): optional condition (ClassicalRegister, int) or (Clbit, bool) + cond (tuple or expr.Expr or None): optional condition in any form that the control-flow + operations accept. Returns: list[Clbit]: list of classical bits - - Raises: - CircuitError: if cond[0] is not ClassicalRegister or Clbit """ - if cond is None: - return [] - elif isinstance(cond[0], ClassicalRegister): - # Returns a list of all the cbits in the given creg cond[0]. - return cond[0][:] - elif isinstance(cond[0], Clbit): - # Returns a singleton list of the conditional cbit. - return [cond[0]] - else: - raise CircuitError("Condition must be used with ClassicalRegister or Clbit.") + return [] if cond is None else list(condition_resources(cond).clbits) def _increment_op(self, op): if op.name in self._op_names: @@ -1133,7 +1123,7 @@ def replace_block_with_op(self, node_block, op, wire_pos_map, cycle_check=True): block_cargs |= set(nd.cargs) cond = getattr(nd.op, "condition", None) if cond is not None: - block_cargs.update(condition_bits(cond)) + block_cargs.update(condition_resources(cond).clbits) # Create replacement node new_node = DAGOpNode( diff --git a/qiskit/dagcircuit/dagdependency.py b/qiskit/dagcircuit/dagdependency.py index 128129e76372..8082fdadfa75 100644 --- a/qiskit/dagcircuit/dagdependency.py +++ b/qiskit/dagcircuit/dagdependency.py @@ -19,7 +19,7 @@ import rustworkx as rx -from qiskit.circuit.controlflow.condition import condition_bits +from qiskit.circuit.controlflow import condition_resources from qiskit.circuit.quantumregister import QuantumRegister, Qubit from qiskit.circuit.classicalregister import ClassicalRegister, Clbit from qiskit.dagcircuit.exceptions import DAGDependencyError @@ -395,7 +395,7 @@ def _create_op_node(self, operation, qargs, cargs): # (1) cindices_list are specific to template optimization and should not be computed # in this place. # (2) Template optimization pass needs currently does not handle general conditions. - cond_bits = condition_bits(operation.condition) + cond_bits = condition_resources(operation.condition).clbits cindices_list = [self.clbits.index(clbit) for clbit in cond_bits] else: cindices_list = [] @@ -592,7 +592,7 @@ def replace_block_with_op(self, node_block, op, wire_pos_map, cycle_check=True): block_cargs |= set(nd.cargs) cond = getattr(nd.op, "condition", None) if cond is not None: - block_cargs.update(condition_bits(cond)) + block_cargs.update(condition_resources(cond).clbits) # Create replacement node new_node = self._create_op_node( diff --git a/test/python/circuit/test_control_flow.py b/test/python/circuit/test_control_flow.py index bac094289d16..e827b679a4a2 100644 --- a/test/python/circuit/test_control_flow.py +++ b/test/python/circuit/test_control_flow.py @@ -14,11 +14,12 @@ import math -from ddt import ddt, data, unpack +from ddt import ddt, data, unpack, idata from qiskit.test import QiskitTestCase from qiskit.circuit import Clbit, ClassicalRegister, Instruction, Parameter, QuantumCircuit, Qubit -from qiskit.circuit.controlflow import CASE_DEFAULT +from qiskit.circuit.classical import expr, types +from qiskit.circuit.controlflow import CASE_DEFAULT, condition_resources, node_resources from qiskit.circuit.library import XGate, RXGate from qiskit.circuit.exceptions import CircuitError @@ -33,19 +34,29 @@ ) +CONDITION_PARAMETRISATION = ( + (Clbit(), True), + (ClassicalRegister(3, "test_creg"), 3), + (ClassicalRegister(3, "test_creg"), True), + expr.lift(Clbit()), + expr.logic_not(Clbit()), + expr.equal(ClassicalRegister(3, "test_creg"), 3), + expr.not_equal(1, ClassicalRegister(3, "test_creg")), +) + + @ddt class TestCreatingControlFlowOperations(QiskitTestCase): """Tests instantiation of instruction subclasses for dynamic QuantumCircuits.""" - @data( - (Clbit(), True), - (ClassicalRegister(3, "test_creg"), 3), - (ClassicalRegister(3, "test_creg"), True), - ) + @idata(CONDITION_PARAMETRISATION) def test_while_loop_instantiation(self, condition): """Verify creation and properties of a WhileLoopOp.""" body = QuantumCircuit(3, 1) - body.add_register([condition[0]] if isinstance(condition[0], Clbit) else condition[0]) + resources = condition_resources(condition) + body.add_bits(resources.clbits) + for reg in resources.cregs: + body.add_register(reg) op = WhileLoopOp(condition, body) @@ -69,6 +80,9 @@ def test_while_loop_invalid_instantiation(self): with self.assertRaisesRegex(CircuitError, r"A classical condition should be a 2-tuple"): _ = WhileLoopOp((Clbit(), None), body) + with self.assertRaisesRegex(CircuitError, r"type 'Bool\(\)'"): + _ = WhileLoopOp(expr.Value(2, types.Uint(2)), body) + with self.assertRaisesRegex(CircuitError, r"of type QuantumCircuit"): _ = WhileLoopOp(condition, XGate()) @@ -177,11 +191,7 @@ def test_for_loop_invalid_params_setter(self): with self.assertRaisesRegex(CircuitError, r"to be either of type Parameter or None"): _ = ForLoopOp(indexset, "foo", body) - @data( - (Clbit(), True), - (ClassicalRegister(3, "test_creg"), 3), - (ClassicalRegister(3, "test_creg"), True), - ) + @idata(CONDITION_PARAMETRISATION) def test_if_else_instantiation_with_else(self, condition): """Verify creation and properties of a IfElseOp with an else branch.""" true_body = QuantumCircuit(3, 1) @@ -198,11 +208,7 @@ def test_if_else_instantiation_with_else(self, condition): self.assertEqual(op.condition, condition) self.assertEqual(op.blocks, (true_body, false_body)) - @data( - (Clbit(), True), - (ClassicalRegister(3, "test_creg"), 3), - (ClassicalRegister(3, "test_creg"), True), - ) + @idata(CONDITION_PARAMETRISATION) def test_if_else_instantiation_without_else(self, condition): """Verify creation and properties of a IfElseOp without an else branch.""" true_body = QuantumCircuit(3, 1) @@ -230,6 +236,9 @@ def test_if_else_invalid_instantiation(self): with self.assertRaisesRegex(CircuitError, r"A classical condition should be a 2-tuple"): _ = IfElseOp((1, 2), true_body, false_body) + with self.assertRaisesRegex(CircuitError, r"type 'Bool\(\)'"): + _ = IfElseOp(expr.Value(2, types.Uint(2)), true_body, false_body) + with self.assertRaisesRegex(CircuitError, r"true_body parameter of type QuantumCircuit"): _ = IfElseOp(condition, XGate()) @@ -332,6 +341,47 @@ def test_switch_register(self): self.assertEqual(op.cases(), {0: case1, 1: case2, 2: case3}) self.assertEqual(list(op.blocks), [case1, case2, case3]) + def test_switch_expr_uint(self): + """Test that a switch statement can be constructed with a Uint `Expr` as a condition.""" + qubit = Qubit() + creg = ClassicalRegister(2) + case1 = QuantumCircuit([qubit], creg) + case1.x(0) + case2 = QuantumCircuit([qubit], creg) + case2.y(0) + case3 = QuantumCircuit([qubit], creg) + case3.z(0) + + op = SwitchCaseOp(expr.lift(creg), [(0, case1), (1, case2), (2, case3)]) + self.assertIsInstance(op, Instruction) + self.assertEqual(op.name, "switch_case") + self.assertEqual(op.num_qubits, 1) + self.assertEqual(op.num_clbits, 2) + self.assertEqual(op.target, expr.Var(creg, types.Uint(creg.size))) + self.assertEqual(op.cases(), {0: case1, 1: case2, 2: case3}) + self.assertEqual(list(op.blocks), [case1, case2, case3]) + + def test_switch_expr_bool(self): + """Test that a switch statement can be constructed with a Bool `Expr` as a condition.""" + qubit = Qubit() + clbit = Clbit() + case1 = QuantumCircuit([qubit, clbit]) + case1.x(0) + case2 = QuantumCircuit([qubit, clbit]) + case2.z(0) + + op = SwitchCaseOp(expr.logic_not(clbit), [(True, case1), (False, case2)]) + self.assertIsInstance(op, Instruction) + self.assertEqual(op.name, "switch_case") + self.assertEqual(op.num_qubits, 1) + self.assertEqual(op.num_clbits, 1) + self.assertEqual( + op.target, + expr.Unary(expr.Unary.Op.LOGIC_NOT, expr.Var(clbit, types.Bool()), types.Bool()), + ) + self.assertEqual(op.cases(), {True: case1, False: case2}) + self.assertEqual(list(op.blocks), [case1, case2]) + def test_switch_with_default(self): """Test that a switch statement can be constructed with a default case at the end.""" qubit = Qubit() @@ -352,6 +402,28 @@ def test_switch_with_default(self): self.assertEqual(op.cases(), {0: case1, 1: case2, CASE_DEFAULT: case3}) self.assertEqual(list(op.blocks), [case1, case2, case3]) + def test_switch_expr_with_default(self): + """Test that a switch statement can be constructed with a default case at the end when the + target is an `Expr`.""" + qubit = Qubit() + creg = ClassicalRegister(2) + case1 = QuantumCircuit([qubit], creg) + case1.x(0) + case2 = QuantumCircuit([qubit], creg) + case2.y(0) + case3 = QuantumCircuit([qubit], creg) + case3.z(0) + + target = expr.bit_xor(creg, 0b11) + op = SwitchCaseOp(target, [(0, case1), (1, case2), (CASE_DEFAULT, case3)]) + self.assertIsInstance(op, Instruction) + self.assertEqual(op.name, "switch_case") + self.assertEqual(op.num_qubits, 1) + self.assertEqual(op.num_clbits, 2) + self.assertEqual(op.target, target) + self.assertEqual(op.cases(), {0: case1, 1: case2, CASE_DEFAULT: case3}) + self.assertEqual(list(op.blocks), [case1, case2, case3]) + def test_switch_multiple_cases_to_same_block(self): """Test that it is possible to add multiple cases that apply to the same block, if they are given as a compound value. This is an allowed special case of block fall-through.""" @@ -447,6 +519,9 @@ class TestAddingControlFlowOperations(QiskitTestCase): (Clbit(), [False, True]), (ClassicalRegister(3, "test_creg"), [3, 1]), (ClassicalRegister(3, "test_creg"), [0, (1, 2), CASE_DEFAULT]), + (expr.lift(Clbit()), [False, True]), + (expr.lift(ClassicalRegister(3, "test_creg")), [3, 1]), + (expr.bit_not(ClassicalRegister(3, "test_creg")), [0, (1, 2), CASE_DEFAULT]), ) @unpack def test_appending_switch_case_op(self, target, labels): @@ -458,8 +533,13 @@ def test_appending_switch_case_op(self, target, labels): qc = QuantumCircuit(5, 2) if isinstance(target, ClassicalRegister): qc.add_register(target) - else: + elif isinstance(target, Clbit): qc.add_bits([target]) + else: + resources = node_resources(target) + qc.add_bits(resources.clbits) + for reg in resources.cregs: + qc.add_register(reg) qc.append(op, [1, 2, 3], [1]) self.assertEqual(qc.data[0].operation.name, "switch_case") @@ -472,6 +552,9 @@ def test_appending_switch_case_op(self, target, labels): (Clbit(), [False, True]), (ClassicalRegister(3, "test_creg"), [3, 1]), (ClassicalRegister(3, "test_creg"), [0, (1, 2), CASE_DEFAULT]), + (expr.lift(Clbit()), [False, True]), + (expr.lift(ClassicalRegister(3, "test_creg")), [3, 1]), + (expr.bit_not(ClassicalRegister(3, "test_creg")), [0, (1, 2), CASE_DEFAULT]), ) @unpack def test_quantumcircuit_switch(self, target, labels): @@ -481,8 +564,13 @@ def test_quantumcircuit_switch(self, target, labels): qc = QuantumCircuit(5, 2) if isinstance(target, ClassicalRegister): qc.add_register(target) - else: + elif isinstance(target, Clbit): qc.add_bits([target]) + else: + resources = node_resources(target) + qc.add_bits(resources.clbits) + for reg in resources.cregs: + qc.add_register(reg) qc.switch(target, zip(labels, bodies), [1, 2, 3], [1]) self.assertEqual(qc.data[0].operation.name, "switch_case") @@ -491,11 +579,7 @@ def test_quantumcircuit_switch(self, target, labels): self.assertEqual(qc.data[0].qubits, tuple(qc.qubits[1:4])) self.assertEqual(qc.data[0].clbits, (qc.clbits[1],)) - @data( - (Clbit(), True), - (ClassicalRegister(3, "test_creg"), 3), - (ClassicalRegister(3, "test_creg"), True), - ) + @idata(CONDITION_PARAMETRISATION) def test_appending_while_loop_op(self, condition): """Verify we can append a WhileLoopOp to a QuantumCircuit.""" body = QuantumCircuit(3, 1) @@ -511,20 +595,16 @@ def test_appending_while_loop_op(self, condition): self.assertEqual(qc.data[0].qubits, tuple(qc.qubits[1:4])) self.assertEqual(qc.data[0].clbits, (qc.clbits[1],)) - @data( - (Clbit(), True), - (ClassicalRegister(3, "test_creg"), 3), - (ClassicalRegister(3, "test_creg"), True), - ) + @idata(CONDITION_PARAMETRISATION) def test_quantumcircuit_while_loop(self, condition): """Verify we can append a WhileLoopOp to a QuantumCircuit via qc.while_loop.""" body = QuantumCircuit(3, 1) qc = QuantumCircuit(5, 2) - if isinstance(condition[0], ClassicalRegister): - qc.add_register(condition[0]) - else: - qc.add_bits([condition[0]]) + resources = condition_resources(condition) + qc.add_bits(resources.clbits) + for reg in resources.cregs: + qc.add_register(reg) qc.while_loop(condition, body, [1, 2, 3], [1]) self.assertEqual(qc.data[0].operation.name, "while_loop") @@ -567,11 +647,7 @@ def test_quantumcircuit_for_loop_op(self): self.assertEqual(qc.data[0].qubits, tuple(qc.qubits[1:4])) self.assertEqual(qc.data[0].clbits, (qc.clbits[1],)) - @data( - (Clbit(), True), - (ClassicalRegister(3, "test_creg"), 3), - (ClassicalRegister(3, "test_creg"), True), - ) + @idata(CONDITION_PARAMETRISATION) def test_appending_if_else_op(self, condition): """Verify we can append a IfElseOp to a QuantumCircuit.""" true_body = QuantumCircuit(3, 1) @@ -580,7 +656,10 @@ def test_appending_if_else_op(self, condition): op = IfElseOp(condition, true_body, false_body) qc = QuantumCircuit(5, 2) - qc.add_register([condition[0]] if isinstance(condition[0], Clbit) else condition[0]) + resources = condition_resources(condition) + qc.add_bits(resources.clbits) + for reg in resources.cregs: + qc.add_register(reg) qc.append(op, [1, 2, 3], [1]) self.assertEqual(qc.data[0].operation.name, "if_else") @@ -589,18 +668,17 @@ def test_appending_if_else_op(self, condition): self.assertEqual(qc.data[0].qubits, tuple(qc.qubits[1:4])) self.assertEqual(qc.data[0].clbits, (qc.clbits[1],)) - @data( - (Clbit(), True), - (ClassicalRegister(3, "test_creg"), 3), - (ClassicalRegister(3, "test_creg"), True), - ) + @idata(CONDITION_PARAMETRISATION) def test_quantumcircuit_if_else_op(self, condition): """Verify we can append a IfElseOp to a QuantumCircuit via qc.if_else.""" true_body = QuantumCircuit(3, 1) false_body = QuantumCircuit(3, 1) qc = QuantumCircuit(5, 2) - qc.add_register([condition[0]] if isinstance(condition[0], Clbit) else condition[0]) + resources = condition_resources(condition) + qc.add_bits(resources.clbits) + for reg in resources.cregs: + qc.add_register(reg) qc.if_else(condition, true_body, false_body, [1, 2, 3], [1]) self.assertEqual(qc.data[0].operation.name, "if_else") @@ -609,17 +687,16 @@ def test_quantumcircuit_if_else_op(self, condition): self.assertEqual(qc.data[0].qubits, tuple(qc.qubits[1:4])) self.assertEqual(qc.data[0].clbits, (qc.clbits[1],)) - @data( - (Clbit(), True), - (ClassicalRegister(3, "test_creg"), 3), - (ClassicalRegister(3, "test_creg"), True), - ) + @idata(CONDITION_PARAMETRISATION) def test_quantumcircuit_if_test_op(self, condition): """Verify we can append a IfElseOp to a QuantumCircuit via qc.if_test.""" true_body = QuantumCircuit(3, 1) qc = QuantumCircuit(5, 2) - qc.add_register([condition[0]] if isinstance(condition[0], Clbit) else condition[0]) + resources = condition_resources(condition) + qc.add_bits(resources.clbits) + for reg in resources.cregs: + qc.add_register(reg) qc.if_test(condition, true_body, [1, 2, 3], [1]) self.assertEqual(qc.data[0].operation.name, "if_else") @@ -628,11 +705,7 @@ def test_quantumcircuit_if_test_op(self, condition): self.assertEqual(qc.data[0].qubits, tuple(qc.qubits[1:4])) self.assertEqual(qc.data[0].clbits, (qc.clbits[1],)) - @data( - (Clbit(), True), - (ClassicalRegister(3, "test_creg"), 3), - (ClassicalRegister(3, "test_creg"), True), - ) + @idata(CONDITION_PARAMETRISATION) def test_appending_if_else_op_with_condition_outside(self, condition): """Verify we catch if IfElseOp has a condition outside outer circuit.""" true_body = QuantumCircuit(3, 1)