Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrapping of legacy device automatically in various device creation/qnode/execute functions #6046

Merged
merged 117 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 107 commits
Commits
Show all changes
117 commits
Select commit Hold shift + click to select a range
ab601c8
adding legacydevicefacade class
albi3ro Apr 1, 2024
b6daa90
more tests
albi3ro Apr 1, 2024
7b20c99
adding to interfaces tests
albi3ro Apr 2, 2024
71db075
finish up torch tests
albi3ro Apr 2, 2024
2e17b76
merge master
albi3ro Jun 4, 2024
583f526
fixing up interface tests
albi3ro Jun 4, 2024
34e0856
starting on testing
albi3ro Jun 10, 2024
6f12302
starting on testing
albi3ro Jun 11, 2024
d32f7f9
Merge branch 'master' into legacy-device-facade-class
albi3ro Jul 2, 2024
900fa35
fixing up tests
albi3ro Jul 2, 2024
0771623
adding some more test coverage [skip-ci]
albi3ro Jul 2, 2024
4dbd068
adding some more tests and coverage
albi3ro Jul 4, 2024
bb6612e
more tests and some docs
albi3ro Jul 8, 2024
f998016
Merge branch 'master' into legacy-device-facade-class
albi3ro Jul 9, 2024
cf46f57
Merge branch 'master' into legacy-device-facade-class
Shiro-Raven Jul 24, 2024
f654249
pass along postselect mode
albi3ro Jul 24, 2024
4e9cc8d
Merge branch 'master' into legacy-device-facade-class
Shiro-Raven Jul 24, 2024
9b7605a
Update pennylane/devices/legacy_facade.py
albi3ro Jul 25, 2024
ebb8f44
resolving merges
albi3ro Jul 25, 2024
30029db
Merge branch 'legacy-device-facade-class' of https://github.com/Penny…
albi3ro Jul 25, 2024
0d11cd2
revert merge problems
albi3ro Jul 26, 2024
0021a54
Update pennylane/devices/legacy_facade.py
albi3ro Jul 26, 2024
7f2e199
Update tests/devices/test_legacy_facade.py
albi3ro Jul 26, 2024
0c7f247
Merge branch 'master' into legacy-device-facade-class
albi3ro Jul 26, 2024
64e59f7
deprecation of backprop device switching
albi3ro Jul 26, 2024
a3d0032
Update pennylane/devices/legacy_facade.py
albi3ro Jul 26, 2024
2c2a764
added facade wrappers, some tests still failing
Shiro-Raven Jul 26, 2024
d71419f
Merge branch 'legacy-device-facade-class' into ad/facade-wrapper
Shiro-Raven Jul 26, 2024
f6b7a52
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Jul 26, 2024
c49a816
more test fixes and renaming of dq2 test file
Shiro-Raven Jul 26, 2024
575c5c3
fix remaining tests
Shiro-Raven Jul 26, 2024
dbcb7a9
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Jul 26, 2024
067efa2
Update pennylane/workflow/qnode.py
Shiro-Raven Jul 26, 2024
d006c87
Update pennylane/workflow/qnode.py
Shiro-Raven Jul 26, 2024
78b47d0
Update pennylane/workflow/qnode.py
Shiro-Raven Jul 26, 2024
ec72210
changelog update
Shiro-Raven Jul 26, 2024
1b85542
more test fixes and merge mashups fixes
Shiro-Raven Jul 26, 2024
68bf9e4
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Jul 26, 2024
10bf1e1
fix snapshot tests
Shiro-Raven Jul 29, 2024
18b855f
fix failing test_gates test
Shiro-Raven Jul 29, 2024
b5b15d8
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Jul 29, 2024
7e95b50
fixed more tests and added metaclass for legacy device API for facade…
Shiro-Raven Jul 29, 2024
5e5dc4c
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Jul 29, 2024
d3fc25d
fix to `TransformedDevice`
Shiro-Raven Jul 29, 2024
e0e7079
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Jul 29, 2024
cb841a1
fix jacobian bug
Shiro-Raven Jul 29, 2024
b94bbb7
more test fixes
Shiro-Raven Jul 30, 2024
6762273
more fixes
Shiro-Raven Jul 30, 2024
1fbd7fe
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Jul 30, 2024
3d198e8
fix for tensorflow test
Shiro-Raven Jul 30, 2024
9d3cc5d
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Jul 30, 2024
28cacd4
more test fixes
Shiro-Raven Jul 30, 2024
c01f3d8
jax tests fixes
Shiro-Raven Jul 30, 2024
bd43ed2
more test fixes
Shiro-Raven Jul 31, 2024
4869235
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Jul 31, 2024
2375c4a
more test fixes
Shiro-Raven Jul 31, 2024
f32c10a
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Jul 31, 2024
af525b5
revert error message
Shiro-Raven Jul 31, 2024
2a0fced
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 1, 2024
3d0cbba
more test fixes
Shiro-Raven Aug 1, 2024
f4b4142
more fixes
Shiro-Raven Aug 1, 2024
b7c7033
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 1, 2024
b8d2125
more tests
Shiro-Raven Aug 1, 2024
04ddf49
weeeeeee
Shiro-Raven Aug 1, 2024
4833934
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 1, 2024
d9712b7
Update pennylane/workflow/jacobian_products.py
Shiro-Raven Aug 2, 2024
fb43e42
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 2, 2024
4a9fba1
fix minor logging test fail
Shiro-Raven Aug 1, 2024
ebdd00a
fix parameterized evolution test
Shiro-Raven Aug 2, 2024
edca13d
fixes
Shiro-Raven Aug 2, 2024
441f0fc
fix
Shiro-Raven Aug 2, 2024
57e8800
hopefully last fix
Shiro-Raven Aug 2, 2024
f812961
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 5, 2024
33e4508
revert unnecessary changes
Shiro-Raven Aug 5, 2024
d18f695
delete unneeded attribute from QNode class
Shiro-Raven Aug 5, 2024
7016cd5
codecov fixes
Shiro-Raven Aug 5, 2024
f1a39cb
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 5, 2024
4b136f3
type hints fix
Shiro-Raven Aug 5, 2024
4fc5671
minor redundancy removal
Shiro-Raven Aug 5, 2024
0aa4a20
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 5, 2024
565ced6
add seed to TF test
Shiro-Raven Aug 5, 2024
b0f0772
revert default gradient method and remove `gradient_kwargs` from jaco…
Shiro-Raven Aug 5, 2024
ee6c610
remove erroneous xfails
Shiro-Raven Aug 5, 2024
90820a9
fix in facade tests
Shiro-Raven Aug 5, 2024
971b92a
no adjoint for non-expvals
albi3ro Aug 6, 2024
7e1c605
[no ci] bump nightly version
ringo-but-quantum Aug 6, 2024
e3cdff1
Fix `qml.center` with linear combinations (#6049)
dwierichs Aug 6, 2024
a281d80
revert repr of QNode to old logic
Shiro-Raven Aug 7, 2024
e2d0969
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 7, 2024
3b07de5
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 8, 2024
ca97469
fix tf legacy drawing test
Shiro-Raven Aug 8, 2024
954689f
generalize adjoint request in facade
Shiro-Raven Aug 8, 2024
5d0253b
Update tests/test_debugging.py
Shiro-Raven Aug 8, 2024
41f115f
fix recursion limit exceeded problem
Shiro-Raven Aug 8, 2024
e501e71
Revert "fix recursion limit exceeded problem"
Shiro-Raven Aug 8, 2024
9db314c
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 8, 2024
fb35f1d
change adjoint ops definition
albi3ro Aug 9, 2024
3956ed4
adjoint is never best
albi3ro Aug 12, 2024
01a9ae9
adjoint allows non-trainable qubit unitary
albi3ro Aug 12, 2024
0b662d3
oops, is_trainable is in operation
albi3ro Aug 12, 2024
5879c18
ignore observable validation
Shiro-Raven Aug 12, 2024
8b91e04
address codecov missed lines
Shiro-Raven Aug 12, 2024
aa8e98b
address final feedback concerns and add codecov no cover
Shiro-Raven Aug 13, 2024
631a291
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 13, 2024
a07dfff
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 13, 2024
54d97d9
address codecov misses
Shiro-Raven Aug 14, 2024
e1cfde3
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 14, 2024
8720103
more codecov fixes and minor code movement
Shiro-Raven Aug 14, 2024
3df7558
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 14, 2024
7a5085c
remove __new__ method from facade class, added error for re-wrapping …
Shiro-Raven Aug 15, 2024
8f75e51
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 15, 2024
7ed7a0b
add copy operations to facade class for catalyst use case
Shiro-Raven Aug 19, 2024
030715d
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 19, 2024
b99804c
trigger CI
Shiro-Raven Aug 19, 2024
b59bb38
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 20, 2024
d7e4438
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 20, 2024
3a97f79
Merge branch 'master' into ad/facade-wrapper
Shiro-Raven Aug 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@
[(#5923)](https://github.com/PennyLaneAI/pennylane/pull/5923)

* `qml.devices.LegacyDeviceFacade` has been added to map the legacy devices to the new
device interface.
device interface, obviating the need to upgrade older devices to the new API.
Functions such as `qml.device`, `qml.execute` and QNode construction
now automatically wrap any passed legacy devices with the facade.
[(#5927)](https://github.com/PennyLaneAI/pennylane/pull/5927)
[(#6046)](https://github.com/PennyLaneAI/pennylane/pull/6046)

* Added the `compute_sparse_matrix` method for `qml.ops.qubit.BasisStateProjector`.
[(#5790)](https://github.com/PennyLaneAI/pennylane/pull/5790)
Expand Down
20 changes: 19 additions & 1 deletion pennylane/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,25 @@ class DeviceError(Exception):
"""


class Device(abc.ABC):
class _LegacyMeta(abc.ABCMeta):
"""
A simple meta class added to circumvent the Legacy facade when
checking the instance of a device against a Legacy device type.

To illustrate, if "dev" is of type LegacyDeviceFacade, and a user is
checking "isinstance(dev, qml.devices.DefaultMixed)", the overridden
"__instancecheck__" will look behind the facade, and will evaluate instead
"isinstance(dev.target_device, qml.devices.DefaultMixed)"
"""

def __instancecheck__(cls, instance):
if isinstance(instance, qml.devices.LegacyDeviceFacade):
return isinstance(instance.target_device, cls)

return super().__instancecheck__(instance)


class Device(abc.ABC, metaclass=_LegacyMeta):
"""Abstract base class for PennyLane devices.

Args:
Expand Down
11 changes: 1 addition & 10 deletions pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,6 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
return qnode_prim


# pylint: disable=protected-access
def _get_device_shots(device) -> "qml.measurements.Shots":
if isinstance(device, qml.devices.LegacyDevice):
if device._shot_vector:
return qml.measurements.Shots(device._raw_shot_sequence)
return qml.measurements.Shots(device.shots)
return device.shots


def qnode_call(qnode: "qml.QNode", *args, **kwargs) -> "qml.typing.Result":
"""A capture compatible call to a QNode. This function is internally used by ``QNode.__call__``.

Expand Down Expand Up @@ -166,7 +157,7 @@ def f(x):
if "shots" in kwargs:
shots = qml.measurements.Shots(kwargs.pop("shots"))
else:
shots = _get_device_shots(qnode.device)
shots = qnode.device.shots
if shots.has_partitioned_shots:
# Questions over the pytrees and the nested result object shape
raise NotImplementedError("shot vectors are not yet supported with plxpr capture.")
Expand Down
5 changes: 0 additions & 5 deletions pennylane/debugging/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,16 +228,11 @@ def get_snapshots(*args, **kwargs):

with _SnapshotDebugger(qnode.device) as dbg:
# pylint: disable=protected-access
if qnode._original_device:
qnode._original_device._debugger = qnode.device._debugger

results = qnode(*args, **kwargs)

# Reset interface
if old_interface == "auto":
qnode.interface = "auto"
if qnode._original_device:
qnode.device._debugger = None

dbg.snapshots["execution_results"] = results
return dbg.snapshots
Expand Down
4 changes: 3 additions & 1 deletion pennylane/devices/device_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ def _safe_specifier_set(version_str):

# Once the device is constructed, we set its custom expansion function if
# any custom decompositions were specified.

if custom_decomps is not None:
if isinstance(dev, qml.devices.LegacyDevice):
custom_decomp_expand_fn = qml.transforms.create_decomp_expand_fn(
Expand All @@ -295,6 +294,9 @@ def _safe_specifier_set(version_str):
)
dev.preprocess = custom_decomp_preprocess

if isinstance(dev, qml.devices.LegacyDevice):
dev = qml.devices.LegacyDeviceFacade(dev)

return dev

raise qml.DeviceError(
Expand Down
76 changes: 56 additions & 20 deletions pennylane/devices/legacy_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@
"""
import warnings

# pylint: disable=not-callable
# pylint: disable=not-callable, unused-argument
from contextlib import contextmanager
from dataclasses import replace

import pennylane as qml
from pennylane.measurements import Shots
from pennylane.measurements import MidMeasureMP, Shots
from pennylane.transforms.core.transform_program import TransformProgram

from .default_qubit import adjoint_observables, adjoint_ops
from .device_api import Device
from .execution_config import DefaultExecutionConfig
from .modifiers import single_tape_support
Expand All @@ -34,10 +33,16 @@
no_sampling,
validate_adjoint_trainable_params,
validate_measurements,
validate_observables,
)


def _requests_adjoint(execution_config):
return execution_config.gradient_method == "adjoint" or (
execution_config.gradient_method == "device"
and execution_config.gradient_keyword_arguments.get("method", None) == "adjoint_jacobian"
)


@contextmanager
def _set_shots(device, shots):
"""Context manager to temporarily change the shots
Expand Down Expand Up @@ -98,6 +103,15 @@ def legacy_device_batch_transform(tape, device):
return _set_shots(device, tape.shots)(device.batch_transform)(tape)


def adjoint_ops(op: qml.operation.Operator) -> bool:
"""Specify whether or not an Operator is supported by adjoint differentiation."""
if isinstance(op, qml.QubitUnitary) and not qml.operation.is_trainable(op):
return True
astralcai marked this conversation as resolved.
Show resolved Hide resolved
return not isinstance(op, MidMeasureMP) and (
Shiro-Raven marked this conversation as resolved.
Show resolved Hide resolved
op.num_params == 0 or (op.num_params == 1 and op.has_generator)
)


def _add_adjoint_transforms(program: TransformProgram, name="adjoint"):
"""Add the adjoint specific transforms to the transform program."""
program.add_transform(no_sampling, name=name)
Expand All @@ -106,9 +120,13 @@ def _add_adjoint_transforms(program: TransformProgram, name="adjoint"):
stopping_condition=adjoint_ops,
name=name,
)
program.add_transform(validate_observables, adjoint_observables, name=name)

def accepted_adjoint_measurements(mp):
return isinstance(mp, qml.measurements.ExpectationMP)

program.add_transform(
validate_measurements,
analytic_measurements=accepted_adjoint_measurements,
name=name,
)
program.add_transform(qml.transforms.broadcast_expand)
Expand Down Expand Up @@ -139,12 +157,16 @@ class LegacyDeviceFacade(Device):

"""

def __new__(cls, device: "qml.devices.LegacyDevice", *args, **kwargs):
return device if isinstance(device, cls) else super().__new__(cls)
astralcai marked this conversation as resolved.
Show resolved Hide resolved

# pylint: disable=super-init-not-called
def __init__(self, device: "qml.devices.LegacyDevice"):
if not isinstance(device, qml.devices.LegacyDevice):
raise ValueError(
"The LegacyDeviceFacade only accepts a device of type qml.devices.LegacyDevice."
)

self._device = device

@property
Expand Down Expand Up @@ -195,13 +217,20 @@ def _debugger(self, new_debugger):
def preprocess(self, execution_config=DefaultExecutionConfig):
execution_config = self._setup_execution_config(execution_config)
program = qml.transforms.core.TransformProgram()
# note: need to wrap these methods with a set_shots modifier

program.add_transform(legacy_device_batch_transform, device=self._device)
program.add_transform(legacy_device_expand_fn, device=self._device)
if execution_config.gradient_method == "adjoint":

if _requests_adjoint(execution_config):
_add_adjoint_transforms(program, name=f"{self.name} + adjoint")

if not self._device.capabilities().get("supports_mid_measure", False):
if self._device.capabilities().get("supports_mid_measure", False):
program.add_transform(
qml.devices.preprocess.mid_circuit_measurements,
device=self,
mcm_config=execution_config.mcm_config,
)
else:
program.add_transform(qml.defer_measurements, device=self)

return program, execution_config
Expand Down Expand Up @@ -230,8 +259,10 @@ def _setup_adjoint_config(self, execution_config):

def _setup_device_config(self, execution_config):
tape = qml.tape.QuantumScript([], [])

if not self._validate_device_method(tape):
raise qml.DeviceError("device does not support device derivatives")

updated_values = {}
if execution_config.use_device_gradient is None:
updated_values["use_device_gradient"] = True
Expand All @@ -243,19 +274,17 @@ def _setup_device_config(self, execution_config):
def _setup_execution_config(self, execution_config):
if execution_config.gradient_method == "best":
tape = qml.tape.QuantumScript([], [])
if self._validate_backprop_method(tape):
config = replace(execution_config, gradient_method="backprop")
return self._setup_backprop_config(config)
if self._validate_adjoint_method(tape):
config = replace(execution_config, gradient_method="adjoint")
return self._setup_adjoint_config(config)
if self._validate_device_method(tape):
config = replace(execution_config, gradient_method="device")
return self._setup_execution_config(config)

if self._validate_backprop_method(tape):
config = replace(execution_config, gradient_method="backprop")
return self._setup_backprop_config(config)

if execution_config.gradient_method == "backprop":
return self._setup_backprop_config(execution_config)
if execution_config.gradient_method == "adjoint":
if _requests_adjoint(execution_config):
return self._setup_adjoint_config(execution_config)
if execution_config.gradient_method == "device":
return self._setup_device_config(execution_config)
Expand All @@ -268,17 +297,17 @@ def supports_derivatives(self, execution_config=None, circuit=None) -> bool:
if execution_config is None or execution_config.gradient_method == "best":
validation_methods = (
self._validate_backprop_method,
self._validate_adjoint_method,
self._validate_device_method,
)
return any(validate(circuit) for validate in validation_methods)

if execution_config.gradient_method == "backprop":
return self._validate_backprop_method(circuit)
if execution_config.gradient_method == "adjoint":
if _requests_adjoint(execution_config):
return self._validate_adjoint_method(circuit)
if execution_config.gradient_method == "device":
return self._validate_device_method(circuit)

return False

# pylint: disable=protected-access
Expand Down Expand Up @@ -333,7 +362,7 @@ def _create_temp_device(self, batch):
backprop_devices[mapped_interface],
wires=self._device.wires,
shots=self._device.shots,
)
).target_device

new_device.expand_fn = expand_fn
new_device.batch_transform = batch_transform
Expand Down Expand Up @@ -368,6 +397,7 @@ def _validate_backprop_method(self, tape):

# determine if the device supports backpropagation
backprop_interface = self._device.capabilities().get("passthru_interface", None)

if backprop_interface is not None:
# device supports backpropagation natively
return mapped_interface in [backprop_interface, "Numpy"]
Expand All @@ -388,9 +418,15 @@ def _validate_adjoint_method(self, tape):
supported_device = all(hasattr(self._device, attr) for attr in required_attrs)
supported_device = supported_device and self._device.capabilities().get("returns_state")

if not supported_device:
if not supported_device or bool(tape.shots):
return False
program = TransformProgram()
_add_adjoint_transforms(program, name=f"{self.name} + adjoint")
try:
program((tape,))
except (qml.operation.DecompositionUndefinedError, qml.DeviceError, AttributeError):
return False
return not bool(tape.shots)
return True

def _validate_device_method(self, _):
# determine if the device provides its own jacobian method
Expand Down
22 changes: 22 additions & 0 deletions pennylane/devices/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Contains shared fixtures for the device tests."""
import argparse
import os
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -41,6 +42,27 @@
}


@pytest.fixture(scope="function", autouse=True)
def capture_legacy_device_deprecation_warnings():
"""Catches all warnings raised by a test and verifies that any Deprecation
warnings released are related to the legacy devices. Otherwise, it re-raises
any unrelated warnings"""

with warnings.catch_warnings(record=True) as recwarn:
warnings.simplefilter("always")
yield

for w in recwarn:
if isinstance(w, qml.PennyLaneDeprecationWarning):
assert "Use of 'default.qubit." in str(w.message)
astralcai marked this conversation as resolved.
Show resolved Hide resolved
assert "is deprecated" in str(w.message)
assert "use 'default.qubit'" in str(w.message)

for w in recwarn:
if "Use of 'default.qubit." not in str(w.message):
warnings.warn(message=w.message, category=w.category)


@pytest.fixture(scope="function")
def tol():
"""Numerical tolerance for equality tests. Returns a different tolerance for tests
Expand Down
1 change: 1 addition & 0 deletions pennylane/devices/tests/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# pylint: disable=too-many-arguments
# pylint: disable=pointless-statement
# pylint: disable=unnecessary-lambda-assignment
# pylint: disable=no-name-in-module
from cmath import exp
from math import cos, sin, sqrt

Expand Down
16 changes: 8 additions & 8 deletions pennylane/devices/tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,15 +1783,15 @@ def circuit():
qml.X(0)
return MyMeasurement()

if isinstance(dev, qml.Device):
with pytest.warns(
with (
pytest.warns(
UserWarning,
match="Requested measurement MyMeasurement with finite shots",
):
circuit()
else:
with pytest.raises(qml.DeviceError):
circuit()
match="MyMeasurement with finite shots; the returned state information is analytic",
)
if isinstance(dev, qml.devices.LegacyDevice)
astralcai marked this conversation as resolved.
Show resolved Hide resolved
else pytest.raises(qml.DeviceError, match="not accepted with finite shots")
):
circuit()

def test_method_overriden_by_device(self, device):
"""Test that the device can override a measurement process."""
Expand Down
Loading
Loading