Skip to content

Commit

Permalink
Extend qml.cond (#2275)
Browse files Browse the repository at this point in the history
* implement and test __invert__

* add else_op to qml.cond

* using inversion in an integration

* format

* copy measurement under the hood when else; else test

* format

* err tests

* format

* docstrings

* isort

* test_condition

* docstrings

* format

* no need for separate logic for ops

* refactor wrapper as per Josh's suggestion from code review

* Update tests/transforms/test_defer_measurements.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* two unit tests for checking qml.cond queuing

* add queue unit tests

* format

* invert docstring

* rename arguments

* module docstring

* module docstring: note where integration tests are

* changelog

* Update pennylane/transforms/condition.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

Co-authored-by: Josh Izaac <josh146@gmail.com>
  • Loading branch information
antalszava and josh146 authored Mar 4, 2022
1 parent 6f2283d commit 0308d78
Show file tree
Hide file tree
Showing 6 changed files with 482 additions and 18 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
has been added to support use cases like quantum teleportation.
[(#2211)](https://github.com/PennyLaneAI/pennylane/pull/2211)
[(#2236)](https://github.com/PennyLaneAI/pennylane/pull/2236)
[(#2275)](https://github.com/PennyLaneAI/pennylane/pull/2275)

The addition includes the `defer_measurements` device-independent transform
that can be applied on devices that have no native mid-circuit measurements
Expand Down
9 changes: 9 additions & 0 deletions pennylane/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,15 @@ def branches(self):
branch_dict[(1,)] = self._one_case
return branch_dict

def __invert__(self):
"""Inverts the control value of the measurement."""
zero = self._zero_case
one = self._one_case

self._control_value = one if self._control_value == zero else zero

return self

def __eq__(self, control_value):
"""Allow asserting measurement values."""
measurement_outcomes = {self._zero_case, self._one_case}
Expand Down
64 changes: 53 additions & 11 deletions pennylane/transforms/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,18 @@
"""
Contains the condition transform.
"""
import inspect
from copy import copy
from functools import wraps
from typing import Type

from pennylane.operation import Operation, AnyWires
from pennylane.measurements import MeasurementValue
from pennylane.operation import AnyWires, Operation, Operator
from pennylane.transforms import make_tape


class ConditionalTransformError(ValueError):
"""Error for using qml.cond incorrectly"""


class Conditional(Operation):
Expand Down Expand Up @@ -57,22 +64,24 @@ def __init__(
super().__init__(wires=then_op.wires, do_queue=do_queue, id=id)


def cond(measurement, then_op):
def cond(condition, true_fn, false_fn=None):
"""Condition a quantum operation on the results of mid-circuit qubit measurements.
Support for using :func:`~.cond` is device-dependent. If a device doesn't
support mid-circuit measurements natively, then the QNode will apply the
:func:`defer_measurements` transform.
Args:
measurement (MeasurementValue): a measurement value to consider, for
example the output of calling :func:`~.measure`
then_op (Operation): The PennyLane operation to apply if the condition
applies.
condition (.MeasurementValue[bool]): a conditional expression involving a mid-circuit
measurement value (see :func:`.pennylane.measure`)
true_fn (callable): The quantum function of PennyLane operation to
apply if ``condition`` is ``True``
false_fn (callable): The quantum function of PennyLane operation to
apply if ``condition`` is ``False``
Returns:
function: A new function that applies the conditional equivalent of ``then_op``. The returned
function takes the same input arguments as ``then_op``.
function: A new function that applies the conditional equivalent of ``true_fn``. The returned
function takes the same input arguments as ``true_fn``.
**Example**
Expand All @@ -92,9 +101,42 @@ def qnode():
qml.cond(m_0, qml.RZ)(sec_par, wires=1)
return qml.expval(qml.PauliZ(1))
"""
if callable(true_fn):
# We assume that the callable is an operation or a quantum function

with_meas_err = (
"Only quantum functions that contain no measurements can be applied conditionally."
)

@wraps(true_fn)
def wrapper(*args, **kwargs):
# We assume that the callable is a quantum function

# 1. Apply true_fn conditionally
tape = make_tape(true_fn)(*args, **kwargs)

if tape.measurements:
raise ConditionalTransformError(with_meas_err)

for op in tape.operations:
Conditional(condition, op)

if false_fn is not None:
# 2. Apply false_fn conditionally
else_tape = make_tape(false_fn)(*args, **kwargs)

if else_tape.measurements:
raise ConditionalTransformError(with_meas_err)

inverted_m = copy(condition)
inverted_m = ~inverted_m

for op in else_tape.operations:
Conditional(inverted_m, op)

@wraps(then_op)
def wrapper(*args, **kwargs):
return Conditional(measurement, then_op(*args, do_queue=False, **kwargs))
else:
raise ConditionalTransformError(
"Only operations and quantum functions with no measurements can be applied conditionally."
)

return wrapper
17 changes: 17 additions & 0 deletions tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,23 @@ def test_measurement_value_assertion(self, val_pair, control_val_idx):
mv == val_pair[control_val_idx]
assert mv._control_value == val_pair[control_val_idx]

@pytest.mark.parametrize("val_pair", [(0, 1), (1, 0), (-1, 1)])
@pytest.mark.parametrize("num_inv, expected_idx", [(1, 0), (2, 1), (3, 0)])
def test_measurement_value_inversion(self, val_pair, num_inv, expected_idx):
"""Test that inverting the value of a measurement works well even with
multiple inversions.
Double-inversion should leave the control value of the measurement
value in place.
"""
zero_case = val_pair[0]
one_case = val_pair[1]
mv = MeasurementValue(measurement_id="1234", zero_case=zero_case, one_case=one_case)
for _ in range(num_inv):
mv = mv.__invert__()

assert mv._control_value == val_pair[expected_idx]

def test_measurement_value_assertion_error(self):
"""Test that the return_type related info is updated for a
measurement."""
Expand Down
Loading

0 comments on commit 0308d78

Please sign in to comment.