forked from Qiskit/qiskit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
new unroll for-loops transpilation pass (Qiskit#9670)
* dynamic circuit optimization: unroll for loops * exceptions * check inside conditional blocks * docs * reno * Update qiskit/transpiler/passes/optimization/unroll_forloops.py Co-authored-by: Jake Lishman <jake@binhbar.com> * parameterless support * moved to utils * no classmethod, but function * docstring and __init__ * Update qiskit/transpiler/passes/optimization/unroll_forloops.py Co-authored-by: Jake Lishman <jake@binhbar.com> * Update test/python/transpiler/test_unroll_forloops.py Co-authored-by: Jake Lishman <jake@binhbar.com> * nested for-loops test * docstring note * new test with c_if * Remove commented-out code --------- Co-authored-by: Jake Lishman <jake@binhbar.com>
- Loading branch information
1 parent
30fabba
commit e98431e
Showing
5 changed files
with
254 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2023. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
""" UnrollForLoops transpilation pass """ | ||
|
||
from qiskit.circuit import ForLoopOp, ContinueLoopOp, BreakLoopOp, IfElseOp | ||
from qiskit.transpiler.basepasses import TransformationPass | ||
from qiskit.transpiler.passes.utils import control_flow | ||
from qiskit.converters import circuit_to_dag | ||
|
||
|
||
class UnrollForLoops(TransformationPass): | ||
"""``UnrollForLoops`` transpilation pass unrolls for-loops when possible.""" | ||
|
||
def __init__(self, max_target_depth=-1): | ||
"""Things like `for x in {0, 3, 4} {rx(x) qr[1];}` will turn into | ||
`rx(0) qr[1]; rx(3) qr[1]; rx(4) qr[1];`. | ||
.. note:: | ||
The ``UnrollForLoops`` unrolls only one level of block depth. No inner loop will | ||
be considered by ``max_target_depth``. | ||
Args: | ||
max_target_depth (int): Optional. Checks if the unrolled block is over a particular | ||
subcircuit depth. To disable the check, use ``-1`` (Default). | ||
""" | ||
super().__init__() | ||
self.max_target_depth = max_target_depth | ||
|
||
@control_flow.trivial_recurse | ||
def run(self, dag): | ||
"""Run the UnrollForLoops pass on `dag`. | ||
Args: | ||
dag (DAGCircuit): the directed acyclic graph to run on. | ||
Returns: | ||
DAGCircuit: Transformed DAG. | ||
""" | ||
for forloop_op in dag.op_nodes(ForLoopOp): | ||
(indexset, loop_param, body) = forloop_op.op.params | ||
|
||
# skip unrolling if it results in bigger than max_target_depth | ||
if 0 < self.max_target_depth < len(indexset) * body.depth(): | ||
continue | ||
|
||
# skip unroll when break_loop or continue_loop inside body | ||
if _body_contains_continue_or_break(body): | ||
continue | ||
|
||
unrolled_dag = circuit_to_dag(body).copy_empty_like() | ||
for index_value in indexset: | ||
bound_body = body.bind_parameters({loop_param: index_value}) if loop_param else body | ||
unrolled_dag.compose(circuit_to_dag(bound_body), inplace=True) | ||
dag.substitute_node_with_dag(forloop_op, unrolled_dag) | ||
|
||
return dag | ||
|
||
|
||
def _body_contains_continue_or_break(circuit): | ||
"""Checks if a circuit contains ``continue``s or ``break``s. Conditional bodies are inspected.""" | ||
for inst in circuit.data: | ||
operation = inst.operation | ||
if isinstance(operation, (ContinueLoopOp, BreakLoopOp)): | ||
return True | ||
if isinstance(operation, IfElseOp): | ||
for block in operation.params: | ||
if _body_contains_continue_or_break(block): | ||
return True | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
--- | ||
features: | ||
- | | ||
The transpiler pass :class:`~.UnrollForLoops` was added. It unrolls for-loops when possible (if | ||
no :class:`~.ContinueLoopOp` or a :class:`~.BreakLoopOp` is inside the body block). | ||
For example ``for x in {0, 3, 4} {rx(x) qr[1];}`` gets converted into ``rx(0) qr[1]; rx(3) qr[1]; rx(4) qr[1];``. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2023. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
"""Test the UnrollForLoops pass""" | ||
|
||
import unittest | ||
|
||
from qiskit.circuit import QuantumCircuit, Parameter, QuantumRegister, ClassicalRegister | ||
from qiskit.transpiler import PassManager | ||
from qiskit.test import QiskitTestCase | ||
from qiskit.transpiler.passes.utils.unroll_forloops import UnrollForLoops | ||
|
||
|
||
class TestUnrollForLoops(QiskitTestCase): | ||
"""Test UnrollForLoops pass""" | ||
|
||
def test_range(self): | ||
"""Check simples unrolling case""" | ||
qreg, creg = QuantumRegister(5, "q"), ClassicalRegister(2, "c") | ||
|
||
body = QuantumCircuit(3, 1) | ||
loop_parameter = Parameter("foo") | ||
indexset = range(0, 10, 2) | ||
|
||
body.rx(loop_parameter, [0, 1, 2]) | ||
|
||
circuit = QuantumCircuit(qreg, creg) | ||
circuit.for_loop(indexset, loop_parameter, body, [1, 2, 3], [1]) | ||
|
||
expected = QuantumCircuit(qreg, creg) | ||
for index_loop in indexset: | ||
expected.rx(index_loop, [1, 2, 3]) | ||
|
||
passmanager = PassManager() | ||
passmanager.append(UnrollForLoops()) | ||
result = passmanager.run(circuit) | ||
|
||
self.assertEqual(result, expected) | ||
|
||
def test_parameterless_range(self): | ||
"""Check simples unrolling case when there is not parameter""" | ||
qreg, creg = QuantumRegister(5, "q"), ClassicalRegister(2, "c") | ||
|
||
body = QuantumCircuit(3, 1) | ||
indexset = range(0, 10, 2) | ||
|
||
body.h([0, 1, 2]) | ||
|
||
circuit = QuantumCircuit(qreg, creg) | ||
circuit.for_loop(indexset, None, body, [1, 2, 3], [1]) | ||
|
||
expected = QuantumCircuit(qreg, creg) | ||
for _ in indexset: | ||
expected.h([1, 2, 3]) | ||
|
||
passmanager = PassManager() | ||
passmanager.append(UnrollForLoops()) | ||
result = passmanager.run(circuit) | ||
|
||
self.assertEqual(result, expected) | ||
|
||
def test_nested_forloop(self): | ||
"""Test unrolls only one level of nested for-loops""" | ||
circuit = QuantumCircuit(1) | ||
twice = range(2) | ||
with circuit.for_loop(twice): | ||
with circuit.for_loop(twice): | ||
circuit.h(0) | ||
|
||
expected = QuantumCircuit(1) | ||
for _ in twice: | ||
for _ in twice: | ||
expected.h(0) | ||
|
||
passmanager = PassManager() | ||
passmanager.append(UnrollForLoops()) | ||
result = passmanager.run(circuit) | ||
|
||
self.assertEqual(result, expected) | ||
|
||
def test_skip_continue_loop(self): | ||
"""Unrolling should not be done when a `continue;` in the body""" | ||
parameter = Parameter("x") | ||
loop_body = QuantumCircuit(1) | ||
loop_body.rx(parameter, 0) | ||
loop_body.continue_loop() | ||
|
||
qc = QuantumCircuit(2) | ||
qc.for_loop([0, 3, 4], parameter, loop_body, [1], []) | ||
qc.x(0) | ||
|
||
passmanager = PassManager() | ||
passmanager.append(UnrollForLoops()) | ||
result = passmanager.run(qc) | ||
|
||
self.assertEqual(result, qc) | ||
|
||
def test_skip_continue_in_conditional(self): | ||
"""Unrolling should not be done when a `continue;` is in a nested condition""" | ||
parameter = Parameter("x") | ||
|
||
true_body = QuantumCircuit(1) | ||
true_body.continue_loop() | ||
false_body = QuantumCircuit(1) | ||
false_body.rx(parameter, 0) | ||
|
||
qr = QuantumRegister(2, name="qr") | ||
cr = ClassicalRegister(2, name="cr") | ||
loop_body = QuantumCircuit(qr, cr) | ||
loop_body.if_else((cr, 0), true_body, false_body, [1], []) | ||
loop_body.x(0) | ||
|
||
qc = QuantumCircuit(qr, cr) | ||
qc.for_loop([0, 3, 4], parameter, loop_body, qr, cr) | ||
|
||
passmanager = PassManager() | ||
passmanager.append(UnrollForLoops()) | ||
result = passmanager.run(qc) | ||
|
||
self.assertEqual(result, qc) | ||
|
||
def test_skip_continue_c_if(self): | ||
"""Unrolling should not be done when a break in the c_if in the body""" | ||
circuit = QuantumCircuit(2, 1) | ||
with circuit.for_loop(range(2)): | ||
circuit.h(0) | ||
circuit.cx(0, 1) | ||
circuit.measure(0, 0) | ||
circuit.break_loop().c_if(0, True) | ||
|
||
passmanager = PassManager() | ||
passmanager.append(UnrollForLoops()) | ||
result = passmanager.run(circuit) | ||
|
||
self.assertEqual(result, circuit) | ||
|
||
def test_max_target_depth(self): | ||
"""Unrolling should not be done when results over `max_target_depth`""" | ||
|
||
loop_parameter = Parameter("foo") | ||
indexset = range(0, 10, 2) | ||
body = QuantumCircuit(3, 1) | ||
body.rx(loop_parameter, [0, 1, 2]) | ||
|
||
qreg, creg = QuantumRegister(5, "q"), ClassicalRegister(2, "c") | ||
circuit = QuantumCircuit(qreg, creg) | ||
circuit.for_loop(indexset, loop_parameter, body, [1, 2, 3], [1]) | ||
|
||
passmanager = PassManager() | ||
passmanager.append(UnrollForLoops(max_target_depth=2)) | ||
result = passmanager.run(circuit) | ||
|
||
self.assertEqual(result, circuit) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |