diff --git a/qiskit/transpiler/passes/__init__.py b/qiskit/transpiler/passes/__init__.py index 662eb6c3324c..73e6ea055165 100644 --- a/qiskit/transpiler/passes/__init__.py +++ b/qiskit/transpiler/passes/__init__.py @@ -176,6 +176,7 @@ ContainsInstruction GatesInBasis ConvertConditionsToIfOps + UnrollForLoops """ # layout selection (placement) @@ -290,3 +291,4 @@ from .utils import ContainsInstruction from .utils import GatesInBasis from .utils import ConvertConditionsToIfOps +from .utils import UnrollForLoops diff --git a/qiskit/transpiler/passes/utils/__init__.py b/qiskit/transpiler/passes/utils/__init__.py index c69244b20b4a..fd2d00bbc4b6 100644 --- a/qiskit/transpiler/passes/utils/__init__.py +++ b/qiskit/transpiler/passes/utils/__init__.py @@ -27,6 +27,7 @@ from .contains_instruction import ContainsInstruction from .gates_basis import GatesInBasis from .convert_conditions_to_if_ops import ConvertConditionsToIfOps +from .unroll_forloops import UnrollForLoops from .minimum_point import MinimumPoint # Utility functions diff --git a/qiskit/transpiler/passes/utils/unroll_forloops.py b/qiskit/transpiler/passes/utils/unroll_forloops.py new file mode 100644 index 000000000000..c83a47f72eec --- /dev/null +++ b/qiskit/transpiler/passes/utils/unroll_forloops.py @@ -0,0 +1,79 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" UnrollForLoops transpilation pass """ + +from qiskit.circuit import ForLoopOp, ContinueLoopOp, BreakLoopOp, IfElseOp +from qiskit.transpiler.basepasses import TransformationPass +from qiskit.transpiler.passes.utils import control_flow +from qiskit.converters import circuit_to_dag + + +class UnrollForLoops(TransformationPass): + """``UnrollForLoops`` transpilation pass unrolls for-loops when possible.""" + + def __init__(self, max_target_depth=-1): + """Things like `for x in {0, 3, 4} {rx(x) qr[1];}` will turn into + `rx(0) qr[1]; rx(3) qr[1]; rx(4) qr[1];`. + + .. note:: + The ``UnrollForLoops`` unrolls only one level of block depth. No inner loop will + be considered by ``max_target_depth``. + + Args: + max_target_depth (int): Optional. Checks if the unrolled block is over a particular + subcircuit depth. To disable the check, use ``-1`` (Default). + """ + super().__init__() + self.max_target_depth = max_target_depth + + @control_flow.trivial_recurse + def run(self, dag): + """Run the UnrollForLoops pass on `dag`. + + Args: + dag (DAGCircuit): the directed acyclic graph to run on. + + Returns: + DAGCircuit: Transformed DAG. + """ + for forloop_op in dag.op_nodes(ForLoopOp): + (indexset, loop_param, body) = forloop_op.op.params + + # skip unrolling if it results in bigger than max_target_depth + if 0 < self.max_target_depth < len(indexset) * body.depth(): + continue + + # skip unroll when break_loop or continue_loop inside body + if _body_contains_continue_or_break(body): + continue + + unrolled_dag = circuit_to_dag(body).copy_empty_like() + for index_value in indexset: + bound_body = body.bind_parameters({loop_param: index_value}) if loop_param else body + unrolled_dag.compose(circuit_to_dag(bound_body), inplace=True) + dag.substitute_node_with_dag(forloop_op, unrolled_dag) + + return dag + + +def _body_contains_continue_or_break(circuit): + """Checks if a circuit contains ``continue``s or ``break``s. Conditional bodies are inspected.""" + for inst in circuit.data: + operation = inst.operation + if isinstance(operation, (ContinueLoopOp, BreakLoopOp)): + return True + if isinstance(operation, IfElseOp): + for block in operation.params: + if _body_contains_continue_or_break(block): + return True + return False diff --git a/releasenotes/notes/unroll-forloops-7bf8000620f738e7.yaml b/releasenotes/notes/unroll-forloops-7bf8000620f738e7.yaml new file mode 100644 index 000000000000..d224c6705fb4 --- /dev/null +++ b/releasenotes/notes/unroll-forloops-7bf8000620f738e7.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + The transpiler pass :class:`~.UnrollForLoops` was added. It unrolls for-loops when possible (if + no :class:`~.ContinueLoopOp` or a :class:`~.BreakLoopOp` is inside the body block). + For example ``for x in {0, 3, 4} {rx(x) qr[1];}`` gets converted into ``rx(0) qr[1]; rx(3) qr[1]; rx(4) qr[1];``. diff --git a/test/python/transpiler/test_unroll_forloops.py b/test/python/transpiler/test_unroll_forloops.py new file mode 100644 index 000000000000..ffff53761a72 --- /dev/null +++ b/test/python/transpiler/test_unroll_forloops.py @@ -0,0 +1,166 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Test the UnrollForLoops pass""" + +import unittest + +from qiskit.circuit import QuantumCircuit, Parameter, QuantumRegister, ClassicalRegister +from qiskit.transpiler import PassManager +from qiskit.test import QiskitTestCase +from qiskit.transpiler.passes.utils.unroll_forloops import UnrollForLoops + + +class TestUnrollForLoops(QiskitTestCase): + """Test UnrollForLoops pass""" + + def test_range(self): + """Check simples unrolling case""" + qreg, creg = QuantumRegister(5, "q"), ClassicalRegister(2, "c") + + body = QuantumCircuit(3, 1) + loop_parameter = Parameter("foo") + indexset = range(0, 10, 2) + + body.rx(loop_parameter, [0, 1, 2]) + + circuit = QuantumCircuit(qreg, creg) + circuit.for_loop(indexset, loop_parameter, body, [1, 2, 3], [1]) + + expected = QuantumCircuit(qreg, creg) + for index_loop in indexset: + expected.rx(index_loop, [1, 2, 3]) + + passmanager = PassManager() + passmanager.append(UnrollForLoops()) + result = passmanager.run(circuit) + + self.assertEqual(result, expected) + + def test_parameterless_range(self): + """Check simples unrolling case when there is not parameter""" + qreg, creg = QuantumRegister(5, "q"), ClassicalRegister(2, "c") + + body = QuantumCircuit(3, 1) + indexset = range(0, 10, 2) + + body.h([0, 1, 2]) + + circuit = QuantumCircuit(qreg, creg) + circuit.for_loop(indexset, None, body, [1, 2, 3], [1]) + + expected = QuantumCircuit(qreg, creg) + for _ in indexset: + expected.h([1, 2, 3]) + + passmanager = PassManager() + passmanager.append(UnrollForLoops()) + result = passmanager.run(circuit) + + self.assertEqual(result, expected) + + def test_nested_forloop(self): + """Test unrolls only one level of nested for-loops""" + circuit = QuantumCircuit(1) + twice = range(2) + with circuit.for_loop(twice): + with circuit.for_loop(twice): + circuit.h(0) + + expected = QuantumCircuit(1) + for _ in twice: + for _ in twice: + expected.h(0) + + passmanager = PassManager() + passmanager.append(UnrollForLoops()) + result = passmanager.run(circuit) + + self.assertEqual(result, expected) + + def test_skip_continue_loop(self): + """Unrolling should not be done when a `continue;` in the body""" + parameter = Parameter("x") + loop_body = QuantumCircuit(1) + loop_body.rx(parameter, 0) + loop_body.continue_loop() + + qc = QuantumCircuit(2) + qc.for_loop([0, 3, 4], parameter, loop_body, [1], []) + qc.x(0) + + passmanager = PassManager() + passmanager.append(UnrollForLoops()) + result = passmanager.run(qc) + + self.assertEqual(result, qc) + + def test_skip_continue_in_conditional(self): + """Unrolling should not be done when a `continue;` is in a nested condition""" + parameter = Parameter("x") + + true_body = QuantumCircuit(1) + true_body.continue_loop() + false_body = QuantumCircuit(1) + false_body.rx(parameter, 0) + + qr = QuantumRegister(2, name="qr") + cr = ClassicalRegister(2, name="cr") + loop_body = QuantumCircuit(qr, cr) + loop_body.if_else((cr, 0), true_body, false_body, [1], []) + loop_body.x(0) + + qc = QuantumCircuit(qr, cr) + qc.for_loop([0, 3, 4], parameter, loop_body, qr, cr) + + passmanager = PassManager() + passmanager.append(UnrollForLoops()) + result = passmanager.run(qc) + + self.assertEqual(result, qc) + + def test_skip_continue_c_if(self): + """Unrolling should not be done when a break in the c_if in the body""" + circuit = QuantumCircuit(2, 1) + with circuit.for_loop(range(2)): + circuit.h(0) + circuit.cx(0, 1) + circuit.measure(0, 0) + circuit.break_loop().c_if(0, True) + + passmanager = PassManager() + passmanager.append(UnrollForLoops()) + result = passmanager.run(circuit) + + self.assertEqual(result, circuit) + + def test_max_target_depth(self): + """Unrolling should not be done when results over `max_target_depth`""" + + loop_parameter = Parameter("foo") + indexset = range(0, 10, 2) + body = QuantumCircuit(3, 1) + body.rx(loop_parameter, [0, 1, 2]) + + qreg, creg = QuantumRegister(5, "q"), ClassicalRegister(2, "c") + circuit = QuantumCircuit(qreg, creg) + circuit.for_loop(indexset, loop_parameter, body, [1, 2, 3], [1]) + + passmanager = PassManager() + passmanager.append(UnrollForLoops(max_target_depth=2)) + result = passmanager.run(circuit) + + self.assertEqual(result, circuit) + + +if __name__ == "__main__": + unittest.main()