Skip to content

Commit

Permalink
Recreate full dag instead of inplace substitution in BasisTranslator (#…
Browse files Browse the repository at this point in the history
…12195)

* Recreate full dag instead of inplace substitution in BasisTranslator

This commit tweaks the internal logic of the basis translator transpiler
pass to do a full dag recreation instead of inplace modification. If
only a few operations were to be substituted it would probably be more
efficient to do an inplace modification, but in general the basis
translator ends up replacing far more operations than not. In such cases
just iterating over the dag and rebuilding it is more efficient because
the overhead of `apply_operation_back()` is minimal compared to
`substitute_node_with_dag()` (although it's higher than
`subtitute_node(.., inplace=True)`).

* Return boolean together with dag in 'apply_translation' to maintain original 'flow_blocks' logic and fix drawer test.

* Remove print

---------

Co-authored-by: Elena Peña Tapia <epenatap@gmail.com>
(cherry picked from commit 37b334f)
  • Loading branch information
mtreinish authored and mergify[bot] committed Jul 30, 2024
1 parent 0c6acb7 commit 82aa707
Showing 1 changed file with 49 additions and 32 deletions.
81 changes: 49 additions & 32 deletions qiskit/transpiler/passes/basis/basis_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,54 +249,58 @@ def run(self, dag):
replace_start_time = time.time()

def apply_translation(dag, wire_map):
dag_updated = False
for node in dag.op_nodes():
is_updated = False
out_dag = dag.copy_empty_like()
for node in dag.topological_op_nodes():
node_qargs = tuple(wire_map[bit] for bit in node.qargs)
qubit_set = frozenset(node_qargs)
if node.name in target_basis or len(node.qargs) < self._min_qubits:
if node.name in CONTROL_FLOW_OP_NAMES:
flow_blocks = []
for block in node.op.blocks:
dag_block = circuit_to_dag(block)
dag_updated = apply_translation(
updated_dag, is_updated = apply_translation(
dag_block,
{
inner: wire_map[outer]
for inner, outer in zip(block.qubits, node.qargs)
},
)
if dag_updated:
flow_circ_block = dag_to_circuit(dag_block)
if is_updated:
flow_circ_block = dag_to_circuit(updated_dag)
else:
flow_circ_block = block
flow_blocks.append(flow_circ_block)
node.op = node.op.replace_blocks(flow_blocks)
out_dag.apply_operation_back(node.op, node.qargs, node.cargs, check=False)
continue
if (
node_qargs in self._qargs_with_non_global_operation
and node.name in self._qargs_with_non_global_operation[node_qargs]
):
out_dag.apply_operation_back(node.op, node.qargs, node.cargs, check=False)
continue

if dag.has_calibration_for(node):
out_dag.apply_operation_back(node.op, node.qargs, node.cargs, check=False)
continue
if qubit_set in extra_instr_map:
self._replace_node(dag, node, extra_instr_map[qubit_set])
self._replace_node(out_dag, node, extra_instr_map[qubit_set])
elif (node.name, node.num_qubits) in instr_map:
self._replace_node(dag, node, instr_map)
self._replace_node(out_dag, node, instr_map)
else:
raise TranspilerError(f"BasisTranslator did not map {node.name}.")
dag_updated = True
return dag_updated
is_updated = True
return out_dag, is_updated

apply_translation(dag, qarg_indices)
out_dag, _ = apply_translation(dag, qarg_indices)
replace_end_time = time.time()
logger.info(
"Basis translation instructions replaced in %.3fs.",
replace_end_time - replace_start_time,
)

return dag
return out_dag

def _replace_node(self, dag, node, instr_map):
target_params, target_dag = instr_map[node.name, node.num_qubits]
Expand All @@ -307,12 +311,18 @@ def _replace_node(self, dag, node, instr_map):
)
if node.params:
parameter_map = dict(zip(target_params, node.params))
bound_target_dag = target_dag.copy_empty_like()
for inner_node in target_dag.topological_op_nodes():
new_node = DAGOpNode.from_instruction(
inner_node._to_circuit_instruction(),
dag=bound_target_dag,
dag=target_dag,
)
new_node.qargs = tuple(
node.qargs[target_dag.find_bit(x).index] for x in inner_node.qargs
)
new_node.cargs = tuple(
node.cargs[target_dag.find_bit(x).index] for x in inner_node.cargs
)

if not new_node.is_standard_gate:
new_node.op = new_node.op.copy()
if any(isinstance(x, ParameterExpression) for x in inner_node.params):
Expand All @@ -334,39 +344,46 @@ def _replace_node(self, dag, node, instr_map):
new_node.params = new_params
if not new_node.is_standard_gate:
new_node.op.params = new_params
bound_target_dag._apply_op_node_back(new_node)
dag._apply_op_node_back(new_node)

if isinstance(target_dag.global_phase, ParameterExpression):
old_phase = target_dag.global_phase
bind_dict = {x: parameter_map[x] for x in old_phase.parameters}
if any(isinstance(x, ParameterExpression) for x in bind_dict.values()):
new_phase = old_phase
for x in bind_dict.items():
new_phase = new_phase.assign(*x)

else:
new_phase = old_phase.bind(bind_dict)
if not new_phase.parameters:
new_phase = new_phase.numeric()
if isinstance(new_phase, complex):
raise TranspilerError(f"Global phase must be real, but got '{new_phase}'")
bound_target_dag.global_phase = new_phase
else:
bound_target_dag = target_dag

if len(bound_target_dag.op_nodes()) == 1 and len(
bound_target_dag.op_nodes()[0].qargs
) == len(node.qargs):
dag_op = bound_target_dag.op_nodes()[0].op
# dag_op may be the same instance as other ops in the dag,
# so if there is a condition, need to copy
if getattr(node, "condition", None):
dag_op = dag_op.copy()
dag.substitute_node(node, dag_op, inplace=True)

if bound_target_dag.global_phase:
dag.global_phase += bound_target_dag.global_phase
dag.global_phase += new_phase

else:
dag.substitute_node_with_dag(node, bound_target_dag)
for inner_node in target_dag.topological_op_nodes():
new_node = DAGOpNode.from_instruction(
inner_node._to_circuit_instruction(),
dag=target_dag,
)
new_node.qargs = tuple(
node.qargs[target_dag.find_bit(x).index] for x in inner_node.qargs
)
new_node.cargs = tuple(
node.cargs[target_dag.find_bit(x).index] for x in inner_node.cargs
)
if not new_node.is_standard_gate:
new_node.op = new_node.op.copy()
# dag_op may be the same instance as other ops in the dag,
# so if there is a condition, need to copy
if getattr(node.op, "condition", None):
new_node_op = new_node.op.to_mutable()
new_node_op.condition = node.op.condition
new_node.op = new_node_op
dag._apply_op_node_back(new_node)
if target_dag.global_phase:
dag.global_phase += target_dag.global_phase

@singledispatchmethod
def _extract_basis(self, circuit):
Expand Down

0 comments on commit 82aa707

Please sign in to comment.