-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
new unroll for-loops transpilation pass #9670
Changes from 7 commits
5158eed
57313e0
eb9b3af
7f1e42a
5a4e8ce
aed085e
9001559
fd8d939
b9f0a5f
4045b35
9422e4f
41312ad
b33e1fa
fbbcdb7
5b13707
925e6fc
ea53541
c1966fe
87bc0b0
8039977
c20f236
f85ae3c
be0a50d
c445575
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2023. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
""" UnrollForLoops transpilation pass """ | ||
|
||
from qiskit.circuit import ForLoopOp, ContinueLoopOp, BreakLoopOp, IfElseOp | ||
from qiskit.transpiler.basepasses import TransformationPass | ||
from qiskit.transpiler.passes.utils import control_flow | ||
from qiskit.converters import circuit_to_dag | ||
|
||
|
||
class UnrollForLoops(TransformationPass): | ||
"""UnrollForLoops transpilation pass unrolls for-loops when possible. Things like | ||
`for x in {0, 3, 4} {rx(x) qr[1];}` will turn into `rx(0) qr[1]; rx(3) qr[1]; rx(4) qr[1];`. | ||
""" | ||
|
||
def __init__(self, max_target_depth=-1): | ||
"""UnrollForLoops transpilation pass unrolls for-loops when possible. Things like | ||
`for x in {0, 3, 4} {rx(x) qr[1];}` will turn into `rx(0) qr[1]; rx(3) qr[1]; rx(4) qr[1];`. | ||
|
||
Args: | ||
max_target_depth (int): Optional. Checks if the unrolled block is over a particular depth. | ||
To disable the check, use ``-1`` (Default) | ||
""" | ||
super().__init__() | ||
self.max_target_depth = max_target_depth | ||
|
||
@control_flow.trivial_recurse | ||
def run(self, dag): | ||
"""Run the UnrollForLoops pass on `dag`. | ||
|
||
Args: | ||
dag (DAGCircuit): the directed acyclic graph to run on. | ||
|
||
Returns: | ||
DAGCircuit: Transformed DAG. | ||
""" | ||
for forloop_op in dag.op_nodes(ForLoopOp): | ||
(indexset, loop_parameter, body) = forloop_op.op.params | ||
|
||
# skip unrolling if it results in bigger than max_target_depth | ||
if self.max_target_depth > 0 and len(indexset) * body.depth() > self.max_target_depth: | ||
1ucian0 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
continue | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think that's probably actually the right choice here, but I think it's worth mentioning: if an inner for loop is not unrolled because of the max-depth constraint (during the depth-first nature of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
# skip unroll when break_loop or continue_loop inside body | ||
if UnrollForLoops.body_contains_continue_or_break(body): | ||
continue | ||
|
||
unrolled_dag = circuit_to_dag(body).copy_empty_like() | ||
for index_value in indexset: | ||
bound_body_dag = circuit_to_dag(body.bind_parameters({loop_parameter: index_value})) | ||
unrolled_dag.compose(bound_body_dag, inplace=True) | ||
dag.substitute_node_with_dag(forloop_op, unrolled_dag) | ||
|
||
return dag | ||
|
||
@classmethod | ||
def body_contains_continue_or_break(cls, circuit): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's no real reason for this to be a class method (doesn't need to access There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done in b33e1fa |
||
"""Checks if a circuit contains ``continue``s or ``break``s. Conditional bodies are inspected.""" | ||
for inst in circuit.data: | ||
operation = inst.operation | ||
for type_ in [ContinueLoopOp, BreakLoopOp]: | ||
if isinstance(operation, type_): | ||
return True | ||
1ucian0 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if isinstance(operation, IfElseOp): | ||
for block in operation.params: | ||
if UnrollForLoops.body_contains_continue_or_break(block): | ||
return True | ||
return False |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
--- | ||
features: | ||
- | | ||
The transpiler pass :class:`~.UnrollForLoops` was added. It unrolls for-loops when possible (if | ||
no :class:`~.ContinueLoopOp` or a :class:`~.BreakLoopOp` is inside the body block). | ||
For example ``for x in {0, 3, 4} {rx(x) qr[1];}`` gets converted into ``rx(0) qr[1]; rx(3) qr[1]; rx(4) qr[1];``. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2023. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
"""Test the UnrollForLoops pass""" | ||
|
||
import unittest | ||
|
||
from qiskit.circuit import QuantumCircuit, Parameter, QuantumRegister, ClassicalRegister | ||
from qiskit.transpiler import PassManager | ||
from qiskit.test import QiskitTestCase | ||
from qiskit.transpiler.passes.optimization.unroll_forloops import UnrollForLoops | ||
|
||
|
||
class TestUnrool(QiskitTestCase): | ||
1ucian0 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Test UnrollForLoops pass""" | ||
|
||
def test_range(self): | ||
"""Check simples unrolling case""" | ||
qreg, creg = QuantumRegister(5, "q"), ClassicalRegister(2, "c") | ||
|
||
body = QuantumCircuit(3, 1) | ||
loop_parameter = Parameter("foo") | ||
indexset = range(0, 10, 2) | ||
|
||
body.rx(loop_parameter, [0, 1, 2]) | ||
|
||
circuit = QuantumCircuit(qreg, creg) | ||
circuit.for_loop(indexset, loop_parameter, body, [1, 2, 3], [1]) | ||
|
||
expected = QuantumCircuit(qreg, creg) | ||
for index_loop in indexset: | ||
expected.rx(index_loop, [1, 2, 3]) | ||
|
||
passmanager = PassManager() | ||
passmanager.append(UnrollForLoops()) | ||
result = passmanager.run(circuit) | ||
|
||
self.assertEqual(result, expected) | ||
|
||
def test_skip_continue_loop(self): | ||
"""Unrolling should not be done when a `continue;` in the body""" | ||
parameter = Parameter("x") | ||
loop_body = QuantumCircuit(1) | ||
loop_body.rx(parameter, 0) | ||
loop_body.continue_loop() | ||
|
||
qc = QuantumCircuit(2) | ||
qc.for_loop([0, 3, 4], parameter, loop_body, [1], []) | ||
qc.x(0) | ||
|
||
passmanager = PassManager() | ||
passmanager.append(UnrollForLoops()) | ||
result = passmanager.run(qc) | ||
|
||
self.assertEqual(result, qc) | ||
|
||
def test_skip_continue_in_conditional(self): | ||
"""Unrolling should not be done when a `continue;` is in a nested condition""" | ||
parameter = Parameter("x") | ||
|
||
true_body = QuantumCircuit(1) | ||
true_body.continue_loop() | ||
false_body = QuantumCircuit(1) | ||
false_body.rx(parameter, 0) | ||
|
||
qr = QuantumRegister(2, name="qr") | ||
cr = ClassicalRegister(2, name="cr") | ||
loop_body = QuantumCircuit(qr, cr) | ||
loop_body.if_else((cr, 0), true_body, false_body, [1], []) | ||
loop_body.x(0) | ||
|
||
qc = QuantumCircuit(qr, cr) | ||
qc.for_loop([0, 3, 4], parameter, loop_body, qr, cr) | ||
|
||
passmanager = PassManager() | ||
passmanager.append(UnrollForLoops()) | ||
result = passmanager.run(qc) | ||
|
||
self.assertEqual(result, qc) | ||
|
||
def test_max_target_depth(self): | ||
"""Unrolling should not be done when results over `max_target_depth`""" | ||
|
||
loop_parameter = Parameter("foo") | ||
indexset = range(0, 10, 2) | ||
body = QuantumCircuit(3, 1) | ||
body.rx(loop_parameter, [0, 1, 2]) | ||
|
||
qreg, creg = QuantumRegister(5, "q"), ClassicalRegister(2, "c") | ||
circuit = QuantumCircuit(qreg, creg) | ||
circuit.for_loop(indexset, loop_parameter, body, [1, 2, 3], [1]) | ||
|
||
passmanager = PassManager() | ||
passmanager.append(UnrollForLoops(max_target_depth=2)) | ||
result = passmanager.run(circuit) | ||
|
||
self.assertEqual(result, circuit) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to repeat documentation between the class docstring and
__init__
- both get concatenated into the docs output. That said, we need to ensure thatqiskit.transpiler.passes.<whichever category>
andqiskit.transpiler.passes
both import this into their namespace, and theqiskit.transpiler.passes
docstring initiates documentation of the class.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed! fixed in fbbcdb7