Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove op.is_hermitian check in expval, counts, sample to allow jit tracing #5506

Merged
merged 19 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@
(which is not currently compatible with `KerasLayer`), linking to instructions to enable Keras 2.
[(#5488)](https://github.com/PennyLaneAI/pennylane/pull/5488)

* Removed the warning that an observable might not be hermitian in `qnode` executions. This enables jit-compilation.
[(#5506)](https://github.com/PennyLaneAI/pennylane/pull/5506)

<h3>Breaking changes 💔</h3>

* The private functions `_pauli_mult`, `_binary_matrix` and `_get_pauli_map` from the `pauli` module have been removed. The same functionality can be achieved using newer features in the ``pauli`` module.
Expand Down
4 changes: 0 additions & 4 deletions pennylane/measurements/counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""
This module contains the qml.counts measurement.
"""
import warnings
from typing import Sequence, Tuple, Optional
import numpy as np

Expand Down Expand Up @@ -146,9 +145,6 @@ def circuit():

return CountsMP(obs=op, all_outcomes=all_outcomes)

if op is not None and not op.is_hermitian: # None type is also allowed for op
warnings.warn(f"{op.name} might not be hermitian.")

if wires is not None:
Qottmann marked this conversation as resolved.
Show resolved Hide resolved
if op is not None:
raise ValueError(
Expand Down
4 changes: 0 additions & 4 deletions pennylane/measurements/expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""
This module contains the qml.expval measurement.
"""
import warnings
from typing import Sequence, Tuple, Union

import pennylane as qml
Expand Down Expand Up @@ -71,9 +70,6 @@ def circuit(x):
"Expectation values of qml.Identity() without wires are currently not allowed."
)

if not op.is_hermitian:
warnings.warn(f"{op.name} might not be hermitian.")

return ExpectationMP(obs=op)


Expand Down
4 changes: 0 additions & 4 deletions pennylane/measurements/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
This module contains the qml.sample measurement.
"""
import functools
import warnings
from typing import Sequence, Tuple, Optional, Union

import numpy as np
Expand Down Expand Up @@ -173,9 +172,6 @@ def __init__(self, obs=None, wires=None, eigvals=None, id=None):
super().__init__(obs=obs)
return

if obs is not None and not obs.is_hermitian: # None type is also allowed for op
warnings.warn(f"{obs.name} might not be hermitian.")

if wires is not None:
if obs is not None:
raise ValueError(
Expand Down
4 changes: 0 additions & 4 deletions pennylane/measurements/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""
This module contains the qml.var measurement.
"""
import warnings
from typing import Sequence, Tuple, Union

import pennylane as qml
Expand Down Expand Up @@ -64,9 +63,6 @@ def circuit(x):
"qml.var does not support measuring sequences of measurements or observables"
)

if not op.is_hermitian:
warnings.warn(f"{op.name} might not be hermitian.")

return VarianceMP(obs=op)


Expand Down
28 changes: 28 additions & 0 deletions tests/devices/default_qubit/test_default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2052,6 +2052,34 @@ def circuit():

assert circuit() == expected

@pytest.mark.jax
@pytest.mark.parametrize("measurement_func", [qml.expval, qml.var])
def test_differentiate_jitted_qnode(self, measurement_func):
"""Test that a jitted qnode can be correctly differentiated"""
import jax

dev = DefaultQubit()

def qfunc(x, y):
qml.RX(x, 0)
return measurement_func(qml.Hamiltonian(y, [qml.Z(0)]))

qnode = qml.QNode(qfunc, dev, interface="jax")
qnode_jit = jax.jit(qml.QNode(qfunc, dev, interface="jax"))

x = jax.numpy.array(0.5)
y = jax.numpy.array([0.5])

res = qnode(x, y)
res_jit = qnode_jit(x, y)

assert qml.math.allclose(res, res_jit)

grad = jax.grad(qnode)(x, y)
grad_jit = jax.grad(qnode_jit)(x, y)

assert qml.math.allclose(grad, grad_jit)


@pytest.mark.parametrize("max_workers", max_workers_list)
def test_broadcasted_parameter(max_workers):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1520,6 +1520,7 @@ def test_hamiltonian_expansion_analytic(
spy = mocker.spy(qml.transforms, "hamiltonian_expand")
obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]

@jax.jit
@qnode(
dev,
interface=interface,
Expand Down
18 changes: 0 additions & 18 deletions tests/measurements/legacy/test_expval_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,6 @@ def circuit(x):

custom_measurement_process(new_dev, spy)

def test_not_an_observable(self, mocker):
"""Test that a warning is raised if the provided
argument might not be hermitian."""
dev = qml.device("default.qubit.legacy", wires=2)

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return qml.expval(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

new_dev = circuit.device
spy = mocker.spy(qml.QubitDevice, "expval")

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()

custom_measurement_process(new_dev, spy)

def test_observable_return_type_is_expectation(self, mocker):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Expectation`"""
dev = qml.device("default.qubit.legacy", wires=2)
Expand Down
18 changes: 0 additions & 18 deletions tests/measurements/legacy/test_measurements_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
Shots,
StateMeasurement,
StateMP,
expval,
var,
)

from pennylane.wires import Wires
Expand Down Expand Up @@ -66,22 +64,6 @@ def test_shape_unrecognized_error():
mp.shape(dev, Shots(None))


@pytest.mark.parametrize("stat_func", [expval, var])
def test_not_an_observable(stat_func):
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""

dev = qml.device("default.qubit.legacy", wires=2)

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return stat_func(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()


class TestSampleMeasurement:
"""Tests for the SampleMeasurement class."""

Expand Down
16 changes: 0 additions & 16 deletions tests/measurements/legacy/test_sample_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,22 +251,6 @@ def circuit():

custom_measurement_process(dev, spy)

def test_not_an_observable(self, mocker):
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""
dev = qml.device("default.qubit.legacy", wires=2, shots=10)
spy = mocker.spy(qml.QubitDevice, "sample")

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return qml.sample(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()

custom_measurement_process(dev, spy)

def test_observable_return_type_is_sample(self, mocker):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Sample`"""
n_shots = 10
Expand Down
16 changes: 0 additions & 16 deletions tests/measurements/legacy/test_var_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,6 @@ def circuit(x):

custom_measurement_process(dev, spy)

def test_not_an_observable(self, mocker):
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""
dev = qml.device("default.qubit.legacy", wires=2)
spy = mocker.spy(qml.QubitDevice, "var")

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return qml.var(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()

custom_measurement_process(dev, spy)

def test_observable_return_type_is_variance(self, mocker):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Variance`"""
dev = qml.device("default.qubit.legacy", wires=2)
Expand Down
7 changes: 0 additions & 7 deletions tests/measurements/test_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,6 @@ def test_providing_observable_and_wires(self):
):
qml.counts(qml.PauliZ(0), wires=[0, 1])

def test_observable_might_not_be_hermitian(self):
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
qml.counts(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

def test_hash(self):
"""Test that the hash property includes the all_outcomes property."""
m1 = qml.counts(all_outcomes=True)
Expand Down
13 changes: 0 additions & 13 deletions tests/measurements/test_expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,6 @@ def circuit(x):
else:
assert res.dtype == r_dtype

def test_not_an_observable(self):
"""Test that a warning is raised if the provided
argument might not be hermitian."""
dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return qml.expval(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()

def test_observable_return_type_is_expectation(self):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Expectation`"""
dev = qml.device("default.qubit", wires=2)
Expand Down
16 changes: 0 additions & 16 deletions tests/measurements/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,22 +293,6 @@ def test_queueing_tensor_observable(self, op1, op2, stat_func, return_type):
assert isinstance(meas_proc, MeasurementProcess)
assert meas_proc.return_type == return_type

def test_not_an_observable(self, stat_func, return_type): # pylint: disable=unused-argument
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""
if stat_func is sample:
pytest.skip("Sampling is not yet supported with symbolic operators.")

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return stat_func(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()


class TestProperties:
"""Test for the properties"""
Expand Down
13 changes: 0 additions & 13 deletions tests/measurements/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,6 @@ def circuit():
assert result[2].dtype == np.dtype("int")
assert np.array_equal(result[2].shape, (n_sample,))

def test_not_an_observable(self):
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""
dev = qml.device("default.qubit", wires=2, shots=10)

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return qml.sample(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()

def test_observable_return_type_is_sample(self):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Sample`"""
n_shots = 10
Expand Down
13 changes: 0 additions & 13 deletions tests/measurements/test_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,6 @@ def circuit(x):
else:
assert res.dtype == r_dtype

def test_not_an_observable(self):
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""
dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return qml.var(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()

def test_observable_return_type_is_variance(self):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Variance`"""
dev = qml.device("default.qubit", wires=2)
Expand Down
4 changes: 2 additions & 2 deletions tests/ops/op_math/test_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,8 +1485,8 @@ def circuit(weights):
true_grad = -qnp.sqrt(2) * qnp.cos(weights[0] / 2) * qnp.sin(weights[0] / 2)
assert qnp.allclose(grad, true_grad)

def test_non_hermitian_obs_not_supported(self):
"""Test that non-hermitian ops in a measurement process will raise a warning."""
def test_non_supported_obs_not_supported(self):
"""Test that non-supported ops in a measurement process will raise an error."""
wires = [0, 1]
dev = qml.device("default.qubit", wires=wires)
prod_op = Prod(qml.RX(1.23, wires=0), qml.Identity(wires=1))
Expand Down
14 changes: 0 additions & 14 deletions tests/ops/op_math/test_sprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,20 +1084,6 @@ def circuit(weights):
true_grad = 100 * -qnp.sqrt(2) * qnp.cos(weights[0] / 2) * qnp.sin(weights[0] / 2)
assert qnp.allclose(grad, true_grad)

def test_non_hermitian_obs_not_supported(self):
"""Test that non-hermitian ops in a measurement process will raise a warning."""
wires = [0, 1]
dev = qml.device("default.qubit", wires=wires)
sprod_op = SProd(1.0 + 2.0j, qml.RX(1.23, wires=0))

@qml.qnode(dev)
def my_circ():
qml.PauliX(0)
return qml.expval(sprod_op)

with pytest.raises(NotImplementedError):
my_circ()

@pytest.mark.torch
@pytest.mark.parametrize("diff_method", ("parameter-shift", "backprop"))
def test_torch(self, diff_method):
Expand Down
14 changes: 0 additions & 14 deletions tests/ops/op_math/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,20 +1267,6 @@ def circuit(weights):
true_grad = qnp.array([-0.09347337, -0.18884787, -0.28818254])
assert qnp.allclose(grad, true_grad)

def test_non_hermitian_op_in_measurement_process(self):
"""Test that non-hermitian ops in a measurement process will raise a warning."""
wires = [0, 1]
dev = qml.device("default.qubit", wires=wires)
sum_op = Sum(Prod(qml.RX(1.23, wires=0), qml.Identity(wires=1)), qml.Identity(wires=1))

@qml.qnode(dev, interface=None)
def my_circ():
qml.PauliX(0)
return qml.expval(sum_op)

with pytest.warns(UserWarning, match="Sum might not be hermitian."):
my_circ()

def test_params_can_be_considered_trainable(self):
"""Tests that the parameters of a Sum are considered trainable."""
dev = qml.device("default.qubit", wires=2)
Expand Down
Loading