From 1f3c21735fa3ddff34ece39eb8822c1764522ecb Mon Sep 17 00:00:00 2001 From: John Lapeyre Date: Tue, 27 Jun 2023 23:08:54 -0400 Subject: [PATCH] Support control flow in ConsolidateBlocks --- .../passes/optimization/consolidate_blocks.py | 37 +++++++++++++++++++ .../transpiler/test_consolidate_blocks.py | 29 +++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/qiskit/transpiler/passes/optimization/consolidate_blocks.py b/qiskit/transpiler/passes/optimization/consolidate_blocks.py index 9f573728572c..af7eaa4b25e4 100644 --- a/qiskit/transpiler/passes/optimization/consolidate_blocks.py +++ b/qiskit/transpiler/passes/optimization/consolidate_blocks.py @@ -23,6 +23,8 @@ from qiskit.extensions import UnitaryGate from qiskit.circuit.library.standard_gates import CXGate from qiskit.transpiler.basepasses import TransformationPass +from qiskit.circuit import ControlFlowOp + from qiskit.transpiler.passes.synthesis import unitary_synthesis @@ -71,6 +73,9 @@ def __init__( ) else: self.decomposer = TwoQubitBasisDecomposer(CXGate()) + self._basis_gates = basis_gates + self._kak_basis_gate = kak_basis_gate + self._approximation_degree = approximation_degree def run(self, dag): """Run the ConsolidateBlocks pass on `dag`. @@ -159,11 +164,43 @@ def run(self, dag): dag.remove_op_node(node) else: dag.replace_block_with_op(run, unitary, {qubit: 0}, cycle_check=False) + + dag = self._handle_control_flow_ops(dag) + # Clear collected blocks and runs as they are no longer valid after consolidation if "run_list" in self.property_set: del self.property_set["run_list"] if "block_list" in self.property_set: del self.property_set["block_list"] + + return dag + + def _handle_control_flow_ops(self, dag): + """ + This is similar to transpiler/passes/utils/control_flow.py except that the + collect blocks is redone for the control flow blocks. + """ + from qiskit.transpiler import PassManager + from qiskit.transpiler.passes import Collect2qBlocks + from qiskit.transpiler.passes import Collect1qRuns + + pass_manager = PassManager() + if "run_list" in self.property_set: + pass_manager.append(Collect1qRuns()) + if "block_list" in self.property_set: + pass_manager.append(Collect2qBlocks()) + + new_consolidate_blocks = self.__class__(self._kak_basis_gate, self.force_consolidate, + self._basis_gates, self._approximation_degree, + self.target) + + pass_manager.append(new_consolidate_blocks) + for node in dag.op_nodes(ControlFlowOp): + mapped_blocks = [] + for block in node.op.blocks: + new_circ = pass_manager.run(block) + mapped_blocks.append(new_circ) + node.op = node.op.replace_blocks(mapped_blocks) return dag def _check_not_in_basis(self, gate_name, qargs, global_index_map): diff --git a/test/python/transpiler/test_consolidate_blocks.py b/test/python/transpiler/test_consolidate_blocks.py index 3332026e436b..f9c51ba237a2 100644 --- a/test/python/transpiler/test_consolidate_blocks.py +++ b/test/python/transpiler/test_consolidate_blocks.py @@ -428,6 +428,35 @@ def test_identity_1q_unitary_is_removed(self): pm = PassManager([Collect2qBlocks(), Collect1qRuns(), ConsolidateBlocks()]) self.assertEqual(QuantumCircuit(5), pm.run(qc)) + def test_descent_into_control_flow(self): + """Test consolidation in blocks of control flow op is the same as at top level.""" + qc = QuantumCircuit(2) + u2gate1 = U2Gate(-1.2, np.pi) + u2gate2 = U2Gate(-3.4, np.pi) + qc.append(u2gate1, [0]) + qc.append(u2gate2, [1]) + qc.cx(0, 1) + qc.cx(1, 0) + + pass_manager = PassManager() + pass_manager.append(Collect2qBlocks()) + pass_manager.append(ConsolidateBlocks(force_consolidate=True)) + result_top = pass_manager.run(qc) + + qc_control_flow = QuantumCircuit(2, 1) + with qc_control_flow.if_test((0, False)): + qc_control_flow.append(u2gate1, [0]) + qc_control_flow.append(u2gate2, [1]) + qc_control_flow.cx(0, 1) + qc_control_flow.cx(1, 0) + + pass_manager = PassManager() + pass_manager.append(Collect2qBlocks()) + pass_manager.append(ConsolidateBlocks(force_consolidate=True)) + result_block = pass_manager.run(qc_control_flow) + gate_top = result_top[0].operation + gate_block = result_block[0].operation.blocks[0][0].operation + self.assertEqual(gate_top, gate_block) if __name__ == "__main__": unittest.main()