Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend the conditional operations documentation #2294

Merged
merged 20 commits into from
Mar 12, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions pennylane/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,17 +636,25 @@ def branches(self):
return branch_dict

def __invert__(self):
"""Inverts the control value of the measurement."""
"""Return a copy of the measurement value with an inverted control
value."""
inverted_self = copy.copy(self)
zero = self._zero_case
one = self._one_case

self._control_value = one if self._control_value == zero else zero
inverted_self._control_value = one if self._control_value == zero else zero

return self
return inverted_self

def __eq__(self, control_value):
"""Allow asserting measurement values."""
measurement_outcomes = {self._zero_case, self._one_case}

if not isinstance(control_value, tuple(type(val) for val in measurement_outcomes)):
raise MeasurementValueError(
f"The equality operator is used to assert measurement outcomes, but got a value with type {type(control_value)}."
)

if control_value not in measurement_outcomes:
raise MeasurementValueError(
f"Unknown measurement value asserted; the set of possible measurement outcomes is: {measurement_outcomes}."
Expand Down
84 changes: 80 additions & 4 deletions pennylane/transforms/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""
Contains the condition transform.
"""
from copy import copy
from functools import wraps
from typing import Type

Expand Down Expand Up @@ -99,6 +98,84 @@ def qnode():
m_1 = qml.measure(2)
qml.cond(m_0, qml.RZ)(sec_par, wires=1)
return qml.expval(qml.PauliZ(1))

.. UsageDetails::

**Conditional quantum functions**

The ``cond`` transform allows conditioning quantum functions too:

.. code-block:: python3

dev = qml.device("default.qubit", wires=2)

def qfunc(par, wires):
qml.Hadamard(wires[0])
qml.RY(par, wires[0])

@qml.qnode(dev)
def qnode():
qml.Hadamard(0)
m_0 = qml.measure(0)
qml.cond(m_0, qfunc)(first_par, wires=[1])
return qml.expval(qml.PauliZ(1))

.. code-block :: pycon

>>> par = np.array(0.3, requires_grad=True)
>>> qnode()
tensor(0.45008329, requires_grad=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe worth having qnode() take the parameter as an argument, e.g., qnode(par)? To avoid users potentially copying and pasting the code, and having odd closure/scope issues :)


**Passing a quantum function for the False case too**

In the qubit model, single-qubit measurements may result in one of two
outcomes. The expression involving a mid-circuit measurement value
passed to ``cond`` may also have two outcomes. ``cond`` allows passing
a quantum functions for both cases at the same time:

.. code-block:: python3

dev = qml.device("default.qubit", wires=2)

def qfunc1(par, wires):
qml.Hadamard(wires[0])
qml.RY(par, wires[0])

def qfunc2(par, wires):
qml.Hadamard(wires[0])
qml.RZ(par, wires[0])

@qml.qnode(dev)
def qnode1():
qml.Hadamard(0)
m_0 = qml.measure(0)
qml.cond(m_0, qfunc1, qfunc2)(par, wires=[1])
return qml.expval(qml.PauliZ(1))

.. code-block :: pycon

>>> par = np.array(0.3, requires_grad=True)
>>> qnode1()
tensor(-0.04991671, requires_grad=True)

The previous QNode is equivalent to using ``cond`` twice, inverting the
measurement value using the ``~`` unary operator in the second case:

.. code-block:: python3

@qml.qnode(dev)
def qnode2():
qml.Hadamard(0)
m_0 = qml.measure(0)
qml.cond(m_0, qfunc1)(par, wires=[1])
qml.cond(~m_0, qfunc2)(par, wires=[1])
return qml.expval(qml.PauliZ(1))

.. code-block :: pycon

>>> par = np.array(0.3, requires_grad=True)
>>> qnode2()
tensor(-0.04991671, requires_grad=True)
"""
if callable(true_fn):
# We assume that the callable is an operation or a quantum function
Expand Down Expand Up @@ -127,11 +204,10 @@ def wrapper(*args, **kwargs):
if else_tape.measurements:
raise ConditionalTransformError(with_meas_err)

inverted_m = copy(condition)
inverted_m = ~inverted_m
inverted_condition = ~condition

for op in else_tape.operations:
Conditional(inverted_m, op)
Conditional(inverted_condition, op)

else:
raise ConditionalTransformError(
Expand Down
16 changes: 15 additions & 1 deletion tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,24 @@ def test_measurement_value_inversion(self, val_pair, num_inv, expected_idx):
one_case = val_pair[1]
mv = MeasurementValue(measurement_id="1234", zero_case=zero_case, one_case=one_case)
for _ in range(num_inv):
mv = mv.__invert__()
mv_new = mv.__invert__()

# Check that inversion involves creating a copy
assert not mv_new is mv

mv = mv_new

assert mv._control_value == val_pair[expected_idx]

def test_measurement_value_assertion_error_wrong_type(self):
"""Test that the return_type related info is updated for a
measurement."""
mv1 = MeasurementValue(measurement_id="1111")
mv2 = MeasurementValue(measurement_id="2222")

with pytest.raises(MeasurementValueError, match="The equality operator is used to assert measurement outcomes, but got a value with type"):
mv1 == mv2

def test_measurement_value_assertion_error(self):
"""Test that the return_type related info is updated for a
measurement."""
Expand Down
43 changes: 35 additions & 8 deletions tests/transforms/test_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,33 @@ def f(x):

assert ops[4].return_type == qml.operation.Probability

def test_cond_queues_with_else(self):
"""Test that qml.cond queues Conditional operations as expected when an
else qfunc is also provided."""

def tape_with_else(f, g, r):
"""Tape that uses cond by passing both a true and false func."""
with qml.tape.QuantumTape() as tape:
m_0 = qml.measure(0)
qml.cond(m_0, f, g)(r)
qml.probs(wires=1)

return tape

def tape_uses_cond_twice(f, g, r):
"""Tape that uses cond twice such that it's equivalent to using cond
with two functions being passed (tape_with_else)."""
with qml.tape.QuantumTape() as tape:
m_0 = qml.measure(0)
qml.cond(m_0, f)(r)
qml.cond(~m_0, g)(r)
qml.probs(wires=1)

return tape

@pytest.mark.parametrize("tape", [tape_with_else, tape_uses_cond_twice])
def test_cond_queues_with_else(self, tape):
"""Test that qml.cond queues Conditional operations as expected in two cases:
1. When an else qfunc is provided;
2. When qml.cond is used twice equivalent to using an else qfunc.
"""
r = 1.234

def f(x):
Expand All @@ -82,11 +106,7 @@ def f(x):
def g(x):
qml.PauliY(1)

with qml.tape.QuantumTape() as tape:
m_0 = qml.measure(0)
qml.cond(m_0, f, g)(r)
qml.probs(wires=1)

tape = tape(f, g, r)
ops = tape.queue
target_wire = qml.wires.Wires(1)

Expand All @@ -111,6 +131,13 @@ def g(x):
assert isinstance(ops[4].then_op, qml.PauliY)
assert ops[4].then_op.wires == target_wire

# Check that: the measurement value is the same for true_fn conditional
# ops
assert ops[1].meas_val is ops[2].meas_val is ops[3].meas_val

# However, it is not the same for the false_fn
assert ops[3].meas_val is not ops[4].meas_val

assert ops[5].return_type == qml.operation.Probability

def test_cond_error(self):
Expand Down