Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend qml.cond #2275

Merged
merged 28 commits into from
Mar 4, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
6ba0bac
implement and test __invert__
antalszava Mar 3, 2022
506ad31
add else_op to qml.cond
antalszava Mar 3, 2022
6325507
using inversion in an integration
antalszava Mar 3, 2022
55f665a
test inversion
antalszava Mar 3, 2022
624362a
format
antalszava Mar 3, 2022
96bd3fc
copy measurement under the hood when else; else test
antalszava Mar 3, 2022
2e7b969
format
antalszava Mar 3, 2022
590eb89
err tests
antalszava Mar 3, 2022
e5bfd12
format
antalszava Mar 3, 2022
611389b
docstrings
antalszava Mar 3, 2022
aa0372c
isort
antalszava Mar 3, 2022
abafb9c
test_condition
antalszava Mar 4, 2022
90b0ee5
docstrings
antalszava Mar 4, 2022
233764d
format
antalszava Mar 4, 2022
9ae53c8
resolve conflict docstring
antalszava Mar 4, 2022
9eec1c2
no need for separate logic for ops
antalszava Mar 4, 2022
58f531a
refactor wrapper as per Josh's suggestion from code review
antalszava Mar 4, 2022
9067e9f
Update tests/transforms/test_defer_measurements.py
antalszava Mar 4, 2022
3614892
two unit tests for checking qml.cond queuing
antalszava Mar 4, 2022
d55fb16
Merge branch 'extend_qml_cond' of github.com:PennyLaneAI/pennylane in…
antalszava Mar 4, 2022
e34de5c
add queue unit tests
antalszava Mar 4, 2022
61a7407
format
antalszava Mar 4, 2022
969e1eb
invert docstring
antalszava Mar 4, 2022
9a841a7
rename arguments
antalszava Mar 4, 2022
7b968fe
module docstring
antalszava Mar 4, 2022
6109a35
module docstring: note where integration tests are
antalszava Mar 4, 2022
b03f177
changelog
antalszava Mar 4, 2022
c27f50d
Update pennylane/transforms/condition.py
antalszava Mar 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
has been added to support use cases like quantum teleportation.
[(#2211)](https://github.com/PennyLaneAI/pennylane/pull/2211)
[(#2236)](https://github.com/PennyLaneAI/pennylane/pull/2236)
[(#2275)](https://github.com/PennyLaneAI/pennylane/pull/2275)

The addition includes the `defer_measurements` device-independent transform
that can be applied on devices that have no native mid-circuit measurements
Expand Down
9 changes: 9 additions & 0 deletions pennylane/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,15 @@ def branches(self):
branch_dict[(1,)] = self._one_case
return branch_dict

def __invert__(self):
"""Inverts the control value of the measurement."""
zero = self._zero_case
one = self._one_case

self._control_value = one if self._control_value == zero else zero

return self

def __eq__(self, control_value):
"""Allow asserting measurement values."""
measurement_outcomes = {self._zero_case, self._one_case}
Expand Down
64 changes: 53 additions & 11 deletions pennylane/transforms/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,18 @@
"""
Contains the condition transform.
"""
import inspect
from copy import copy
from functools import wraps
from typing import Type

from pennylane.operation import Operation, AnyWires
from pennylane.measurements import MeasurementValue
from pennylane.operation import AnyWires, Operation, Operator
from pennylane.transforms import make_tape


class ConditionalTransformError(ValueError):
"""Error for using qml.cond incorrectly"""


class Conditional(Operation):
Expand Down Expand Up @@ -57,22 +64,24 @@ def __init__(
super().__init__(wires=then_op.wires, do_queue=do_queue, id=id)


def cond(measurement, then_op):
def cond(condition, true_fn, false_fn=None):
"""Condition a quantum operation on the results of mid-circuit qubit measurements.

Support for using :func:`~.cond` is device-dependent. If a device doesn't
support mid-circuit measurements natively, then the QNode will apply the
:func:`defer_measurements` transform.

Args:
measurement (MeasurementValue): a measurement value to consider, for
example the output of calling :func:`~.measure`
then_op (Operation): The PennyLane operation to apply if the condition
applies.
condition (MeasurementValue[bool]): a conditional expression involving a mid-circuit
antalszava marked this conversation as resolved.
Show resolved Hide resolved
measurement value (see :func:`.pennylane.measure`)
true_fn (callable): The quantum function of PennyLane operation to
apply if ``condition`` is ``True``
false_fn (callable): The quantum function of PennyLane operation to
apply if ``condition`` is ``False``

Returns:
function: A new function that applies the conditional equivalent of ``then_op``. The returned
function takes the same input arguments as ``then_op``.
function: A new function that applies the conditional equivalent of ``true_fn``. The returned
function takes the same input arguments as ``true_fn``.

**Example**
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@antalszava the example here is not up to our usual standards for docstring examples (e.g., it doesn't contain any text explaining the context/code, and showing outputs).

It also seems to be missing:

  • examples with qfuncs,
  • examples with then/else, and
  • examples showing how to deal with conditional functions with different signatures


Expand All @@ -92,9 +101,42 @@ def qnode():
qml.cond(m_0, qml.RZ)(sec_par, wires=1)
return qml.expval(qml.PauliZ(1))
"""
if callable(true_fn):
# We assume that the callable is an operation or a quantum function

with_meas_err = (
"Only quantum functions that contain no measurements can be applied conditionally."
josh146 marked this conversation as resolved.
Show resolved Hide resolved
)

@wraps(true_fn)
def wrapper(*args, **kwargs):
# We assume that the callable is a quantum function

# 1. Apply true_fn conditionally
tape = make_tape(true_fn)(*args, **kwargs)

if tape.measurements:
raise ConditionalTransformError(with_meas_err)

for op in tape.operations:
Conditional(condition, op)

if false_fn is not None:
# 2. Apply false_fn conditionally
else_tape = make_tape(false_fn)(*args, **kwargs)

if else_tape.measurements:
raise ConditionalTransformError(with_meas_err)

inverted_m = copy(condition)
inverted_m = ~inverted_m

for op in else_tape.operations:
Conditional(inverted_m, op)

@wraps(then_op)
def wrapper(*args, **kwargs):
return Conditional(measurement, then_op(*args, do_queue=False, **kwargs))
else:
raise ConditionalTransformError(
"Only operations and quantum functions with no measurements can be applied conditionally."
)

return wrapper
17 changes: 17 additions & 0 deletions tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,23 @@ def test_measurement_value_assertion(self, val_pair, control_val_idx):
mv == val_pair[control_val_idx]
assert mv._control_value == val_pair[control_val_idx]

@pytest.mark.parametrize("val_pair", [(0, 1), (1, 0), (-1, 1)])
@pytest.mark.parametrize("num_inv, expected_idx", [(1, 0), (2, 1), (3, 0)])
def test_measurement_value_inversion(self, val_pair, num_inv, expected_idx):
"""Test that inverting the value of a measurement works well even with
multiple inversions.
antalszava marked this conversation as resolved.
Show resolved Hide resolved

Double-inversion should leave the control value of the measurement
value in place.
"""
zero_case = val_pair[0]
one_case = val_pair[1]
mv = MeasurementValue(measurement_id="1234", zero_case=zero_case, one_case=one_case)
for _ in range(num_inv):
mv = mv.__invert__()

assert mv._control_value == val_pair[expected_idx]

def test_measurement_value_assertion_error(self):
"""Test that the return_type related info is updated for a
measurement."""
Expand Down
Loading