Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clarify that qnodes with qml.sample are not differentiable #5237

Merged
merged 3 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,12 @@
* A warning about two mathematically equivalent Hamiltonians undergoing different time evolutions was added to `qml.TrotterProduct` and `qml.ApproxTimeEvolution`.
[(#5137)](https://github.com/PennyLaneAI/pennylane/pull/5137)

* Added a reference to the paper that provides the image of the `qml.QAOAEmbedding` template. [(#5130)](https://github.com/PennyLaneAI/pennylane/pull/5130)
* Added a reference to the paper that provides the image of the `qml.QAOAEmbedding` template.
[(#5130)](https://github.com/PennyLaneAI/pennylane/pull/5130)

* The docstring of `qml.sample` has been updated to advise the use of single-shot expectations
instead when differentiating a circuit.
timmysilv marked this conversation as resolved.
Show resolved Hide resolved
[(#5237)](https://github.com/PennyLaneAI/pennylane/pull/5237)

<h3>Bug fixes 🐛</h3>

Expand Down
35 changes: 31 additions & 4 deletions pennylane/measurements/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,37 @@ def sample(op: Optional[Union[Operator, MeasurementValue]] = None, wires=None) -
.. note::

QNodes that return samples cannot, in general, be differentiated, since the derivative
with respect to a sample --- a stochastic process --- is ill-defined. The one exception
is if the QNode uses the parameter-shift method (``diff_method="parameter-shift"``), in
which case ``qml.sample(obs)`` is interpreted as a single-shot expectation value of the
observable ``obs``.
with respect to a sample --- a stochastic process --- is ill-defined. An alternative
approach would be to use single-shot expectation values. For example, instead of this:
timmysilv marked this conversation as resolved.
Show resolved Hide resolved

.. code-block:: python

dev = qml.device("default.qubit", shots=10)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(angle):
qml.RX(angle, wires=0)
return qml.sample(qml.PauliX(0))

angle = qml.numpy.array(0.1)
res = qml.jacobian(circuit)(angle)

Consider using :func:`~pennylane.expval` and a sequence of single shots, like this:

.. code-block:: python

dev = qml.device("default.qubit", shots=[(1, 10)])

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(angle):
qml.RX(angle, wires=0)
return qml.expval(qml.PauliX(0))

def cost(angle):
return qml.math.hstack(circuit(angle))

angle = qml.numpy.array(0.1)
res = qml.jacobian(cost)(angle)

**Example**

Expand Down
30 changes: 30 additions & 0 deletions tests/measurements/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,36 @@ class DummyOp(Operator): # pylint: disable=too-few-public-methods
with pytest.raises(EigvalsUndefinedError, match="Cannot compute samples of"):
qml.sample(op=DummyOp(0)).process_samples(samples=np.array([[1, 0]]), wire_order=[0])

def test_sample_allowed_with_parameter_shift(self):
"""Test that qml.sample doesn't raise an error with parameter-shift and autograd."""
dev = qml.device("default.qubit", shots=10)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(angle):
qml.RX(angle, wires=0)
return qml.sample(qml.PauliX(0))

angle = qml.numpy.array(0.1)
res = qml.jacobian(circuit)(angle)
assert qml.math.shape(res) == (10,)
assert all(r in {-1, 0, 1} for r in np.round(res, 13))

@pytest.mark.jax
def test_sample_fails_with_jax_jacobian(self):
"""Test that qml.sample raises an error with parameter-shift and jax."""
import jax

dev = qml.device("default.qubit", shots=10)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(angle):
qml.RX(angle, wires=0)
return qml.sample(qml.PauliX(0))

angle = jax.numpy.array(0.1)
with pytest.raises(TypeError, match=r"got int64\[10\] and float64\[10\] respectively"):
_ = jax.jacobian(circuit)(angle)


@pytest.mark.jax
@pytest.mark.parametrize("samples", (1, 10))
Expand Down
Loading