Skip to content

Commit

Permalink
Wrapping of legacy device automatically in various device creation/qn…
Browse files Browse the repository at this point in the history
…ode/execute functions (#6046)

**Context:**
With the `LegacyDeviceFacade` now in place, we can add automatic
wrapping of legacy devices.

**Description of the Change:**
Add automatic wrapping to `qml.device`, `qml.execute`, `QNode`
constructor, and the `get_best_method` and `best_method_str` functions
of the QNode class. The tests are also updated accordingly.

**Benefits:**
Users no longer need to worry about upgrading their devices to the new
Device API and can use the facade to access the basic functions of the
new API.

**Possible Drawbacks:**
The facade doesn't yet provide all potential advantages of fully
upgrading to the new API

[[sc-65998](https://app.shortcut.com/xanaduai/story/65998)]

---------

Co-authored-by: albi3ro <chrissie.c.l@gmail.com>
Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Co-authored-by: ringo-but-quantum <github-ringo-but-quantum@xanadu.ai>
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
  • Loading branch information
8 people authored Aug 21, 2024
1 parent dea7a2d commit 1f55c88
Show file tree
Hide file tree
Showing 68 changed files with 1,285 additions and 2,018 deletions.
11 changes: 1 addition & 10 deletions pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,6 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
return qnode_prim


# pylint: disable=protected-access
def _get_device_shots(device) -> "qml.measurements.Shots":
if isinstance(device, qml.devices.LegacyDevice):
if device._shot_vector:
return qml.measurements.Shots(device._raw_shot_sequence)
return qml.measurements.Shots(device.shots)
return device.shots


def qnode_call(qnode: "qml.QNode", *args, **kwargs) -> "qml.typing.Result":
"""A capture compatible call to a QNode. This function is internally used by ``QNode.__call__``.
Expand Down Expand Up @@ -166,7 +157,7 @@ def f(x):
if "shots" in kwargs:
shots = qml.measurements.Shots(kwargs.pop("shots"))
else:
shots = _get_device_shots(qnode.device)
shots = qnode.device.shots
if shots.has_partitioned_shots:
# Questions over the pytrees and the nested result object shape
raise NotImplementedError("shot vectors are not yet supported with plxpr capture.")
Expand Down
5 changes: 0 additions & 5 deletions pennylane/debugging/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,16 +228,11 @@ def get_snapshots(*args, **kwargs):

with _SnapshotDebugger(qnode.device) as dbg:
# pylint: disable=protected-access
if qnode._original_device:
qnode._original_device._debugger = qnode.device._debugger

results = qnode(*args, **kwargs)

# Reset interface
if old_interface == "auto":
qnode.interface = "auto"
if qnode._original_device:
qnode.device._debugger = None

dbg.snapshots["execution_results"] = results
return dbg.snapshots
Expand Down
20 changes: 19 additions & 1 deletion pennylane/devices/_legacy_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,25 @@ def _local_tape_expand(tape, depth, stop_at):
return new_tape


class Device(abc.ABC):
class _LegacyMeta(abc.ABCMeta):
"""
A simple meta class added to circumvent the Legacy facade when
checking the instance of a device against a Legacy device type.
To illustrate, if "dev" is of type LegacyDeviceFacade, and a user is
checking "isinstance(dev, qml.devices.DefaultMixed)", the overridden
"__instancecheck__" will look behind the facade, and will evaluate instead
"isinstance(dev.target_device, qml.devices.DefaultMixed)"
"""

def __instancecheck__(cls, instance):
if isinstance(instance, qml.devices.LegacyDeviceFacade):
return isinstance(instance.target_device, cls)

return super().__instancecheck__(instance)


class Device(abc.ABC, metaclass=_LegacyMeta):
"""Abstract base class for PennyLane devices.
Args:
Expand Down
4 changes: 3 additions & 1 deletion pennylane/devices/device_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def _safe_specifier_set(version_str):

# Once the device is constructed, we set its custom expansion function if
# any custom decompositions were specified.

if custom_decomps is not None:
if isinstance(dev, qml.devices.LegacyDevice):
custom_decomp_expand_fn = qml.transforms.create_decomp_expand_fn(
Expand All @@ -294,6 +293,9 @@ def _safe_specifier_set(version_str):
)
dev.preprocess = custom_decomp_preprocess

if isinstance(dev, qml.devices.LegacyDevice):
dev = qml.devices.LegacyDeviceFacade(dev)

return dev

raise qml.DeviceError(
Expand Down
84 changes: 64 additions & 20 deletions pennylane/devices/legacy_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
"""
import warnings

# pylint: disable=not-callable
# pylint: disable=not-callable, unused-argument
from contextlib import contextmanager
from copy import copy, deepcopy
from dataclasses import replace

import pennylane as qml
from pennylane.measurements import Shots
from pennylane.measurements import MidMeasureMP, Shots
from pennylane.transforms.core.transform_program import TransformProgram

from .default_qubit import adjoint_observables, adjoint_ops
from .device_api import Device
from .execution_config import DefaultExecutionConfig
from .modifiers import single_tape_support
Expand All @@ -34,10 +34,16 @@
no_sampling,
validate_adjoint_trainable_params,
validate_measurements,
validate_observables,
)


def _requests_adjoint(execution_config):
return execution_config.gradient_method == "adjoint" or (
execution_config.gradient_method == "device"
and execution_config.gradient_keyword_arguments.get("method", None) == "adjoint_jacobian"
)


@contextmanager
def _set_shots(device, shots):
"""Context manager to temporarily change the shots
Expand Down Expand Up @@ -98,6 +104,15 @@ def legacy_device_batch_transform(tape, device):
return _set_shots(device, tape.shots)(device.batch_transform)(tape)


def adjoint_ops(op: qml.operation.Operator) -> bool:
"""Specify whether or not an Operator is supported by adjoint differentiation."""
if isinstance(op, qml.QubitUnitary) and not qml.operation.is_trainable(op):
return True
return not isinstance(op, MidMeasureMP) and (
op.num_params == 0 or (op.num_params == 1 and op.has_generator)
)


def _add_adjoint_transforms(program: TransformProgram, name="adjoint"):
"""Add the adjoint specific transforms to the transform program."""
program.add_transform(no_sampling, name=name)
Expand All @@ -106,9 +121,13 @@ def _add_adjoint_transforms(program: TransformProgram, name="adjoint"):
stopping_condition=adjoint_ops,
name=name,
)
program.add_transform(validate_observables, adjoint_observables, name=name)

def accepted_adjoint_measurements(mp):
return isinstance(mp, qml.measurements.ExpectationMP)

program.add_transform(
validate_measurements,
analytic_measurements=accepted_adjoint_measurements,
name=name,
)
program.add_transform(qml.transforms.broadcast_expand)
Expand Down Expand Up @@ -141,10 +160,14 @@ class LegacyDeviceFacade(Device):

# pylint: disable=super-init-not-called
def __init__(self, device: "qml.devices.LegacyDevice"):
if isinstance(device, type(self)):
raise RuntimeError("An already-facaded device can not be wrapped in a facade again.")

if not isinstance(device, qml.devices.LegacyDevice):
raise ValueError(
"The LegacyDeviceFacade only accepts a device of type qml.devices.LegacyDevice."
)

self._device = device

@property
Expand All @@ -168,6 +191,13 @@ def __repr__(self):
def __getattr__(self, name):
return getattr(self._device, name)

# These custom copy methods are needed for Catalyst
def __copy__(self):
return type(self)(copy(self.target_device))

def __deepcopy__(self, memo):
return type(self)(deepcopy(self.target_device, memo))

@property
def target_device(self) -> "qml.devices.LegacyDevice":
"""The device wrapped by the facade."""
Expand Down Expand Up @@ -195,13 +225,20 @@ def _debugger(self, new_debugger):
def preprocess(self, execution_config=DefaultExecutionConfig):
execution_config = self._setup_execution_config(execution_config)
program = qml.transforms.core.TransformProgram()
# note: need to wrap these methods with a set_shots modifier

program.add_transform(legacy_device_batch_transform, device=self._device)
program.add_transform(legacy_device_expand_fn, device=self._device)
if execution_config.gradient_method == "adjoint":

if _requests_adjoint(execution_config):
_add_adjoint_transforms(program, name=f"{self.name} + adjoint")

if not self._device.capabilities().get("supports_mid_measure", False):
if self._device.capabilities().get("supports_mid_measure", False):
program.add_transform(
qml.devices.preprocess.mid_circuit_measurements,
device=self,
mcm_config=execution_config.mcm_config,
)
else:
program.add_transform(qml.defer_measurements, device=self)

return program, execution_config
Expand Down Expand Up @@ -230,8 +267,10 @@ def _setup_adjoint_config(self, execution_config):

def _setup_device_config(self, execution_config):
tape = qml.tape.QuantumScript([], [])

if not self._validate_device_method(tape):
raise qml.DeviceError("device does not support device derivatives")

updated_values = {}
if execution_config.use_device_gradient is None:
updated_values["use_device_gradient"] = True
Expand All @@ -243,19 +282,17 @@ def _setup_device_config(self, execution_config):
def _setup_execution_config(self, execution_config):
if execution_config.gradient_method == "best":
tape = qml.tape.QuantumScript([], [])
if self._validate_backprop_method(tape):
config = replace(execution_config, gradient_method="backprop")
return self._setup_backprop_config(config)
if self._validate_adjoint_method(tape):
config = replace(execution_config, gradient_method="adjoint")
return self._setup_adjoint_config(config)
if self._validate_device_method(tape):
config = replace(execution_config, gradient_method="device")
return self._setup_execution_config(config)

if self._validate_backprop_method(tape):
config = replace(execution_config, gradient_method="backprop")
return self._setup_backprop_config(config)

if execution_config.gradient_method == "backprop":
return self._setup_backprop_config(execution_config)
if execution_config.gradient_method == "adjoint":
if _requests_adjoint(execution_config):
return self._setup_adjoint_config(execution_config)
if execution_config.gradient_method == "device":
return self._setup_device_config(execution_config)
Expand All @@ -268,17 +305,17 @@ def supports_derivatives(self, execution_config=None, circuit=None) -> bool:
if execution_config is None or execution_config.gradient_method == "best":
validation_methods = (
self._validate_backprop_method,
self._validate_adjoint_method,
self._validate_device_method,
)
return any(validate(circuit) for validate in validation_methods)

if execution_config.gradient_method == "backprop":
return self._validate_backprop_method(circuit)
if execution_config.gradient_method == "adjoint":
if _requests_adjoint(execution_config):
return self._validate_adjoint_method(circuit)
if execution_config.gradient_method == "device":
return self._validate_device_method(circuit)

return False

# pylint: disable=protected-access
Expand Down Expand Up @@ -333,7 +370,7 @@ def _create_temp_device(self, batch):
backprop_devices[mapped_interface],
wires=self._device.wires,
shots=self._device.shots,
)
).target_device

new_device.expand_fn = expand_fn
new_device.batch_transform = batch_transform
Expand Down Expand Up @@ -368,6 +405,7 @@ def _validate_backprop_method(self, tape):

# determine if the device supports backpropagation
backprop_interface = self._device.capabilities().get("passthru_interface", None)

if backprop_interface is not None:
# device supports backpropagation natively
return mapped_interface in [backprop_interface, "Numpy"]
Expand All @@ -388,9 +426,15 @@ def _validate_adjoint_method(self, tape):
supported_device = all(hasattr(self._device, attr) for attr in required_attrs)
supported_device = supported_device and self._device.capabilities().get("returns_state")

if not supported_device:
if not supported_device or bool(tape.shots):
return False
program = TransformProgram()
_add_adjoint_transforms(program, name=f"{self.name} + adjoint")
try:
program((tape,))
except (qml.operation.DecompositionUndefinedError, qml.DeviceError, AttributeError):
return False
return not bool(tape.shots)
return True

def _validate_device_method(self, _):
# determine if the device provides its own jacobian method
Expand Down
22 changes: 22 additions & 0 deletions pennylane/devices/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Contains shared fixtures for the device tests."""
import argparse
import os
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -41,6 +42,27 @@
}


@pytest.fixture(scope="function", autouse=True)
def capture_legacy_device_deprecation_warnings():
"""Catches all warnings raised by a test and verifies that any Deprecation
warnings released are related to the legacy devices. Otherwise, it re-raises
any unrelated warnings"""

with warnings.catch_warnings(record=True) as recwarn:
warnings.simplefilter("always")
yield

for w in recwarn:
if isinstance(w, qml.PennyLaneDeprecationWarning):
assert "Use of 'default.qubit." in str(w.message)
assert "is deprecated" in str(w.message)
assert "use 'default.qubit'" in str(w.message)

for w in recwarn:
if "Use of 'default.qubit." not in str(w.message):
warnings.warn(message=w.message, category=w.category)


@pytest.fixture(scope="function")
def tol():
"""Numerical tolerance for equality tests. Returns a different tolerance for tests
Expand Down
1 change: 1 addition & 0 deletions pennylane/devices/tests/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# pylint: disable=too-many-arguments
# pylint: disable=pointless-statement
# pylint: disable=unnecessary-lambda-assignment
# pylint: disable=no-name-in-module
from cmath import exp
from math import cos, sin, sqrt

Expand Down
16 changes: 8 additions & 8 deletions pennylane/devices/tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,15 +1783,15 @@ def circuit():
qml.X(0)
return MyMeasurement()

if isinstance(dev, qml.Device):
with pytest.warns(
with (
pytest.warns(
UserWarning,
match="Requested measurement MyMeasurement with finite shots",
):
circuit()
else:
with pytest.raises(qml.DeviceError):
circuit()
match="MyMeasurement with finite shots; the returned state information is analytic",
)
if isinstance(dev, qml.devices.LegacyDevice)
else pytest.raises(qml.DeviceError, match="not accepted with finite shots")
):
circuit()

def test_method_overriden_by_device(self, device):
"""Test that the device can override a measurement process."""
Expand Down
12 changes: 4 additions & 8 deletions pennylane/optimize/qnspsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,15 +442,11 @@ def _apply_blocking(self, cost, args, kwargs, params_next):
cost.construct(params_next, kwargs)
tape_loss_next = cost.tape.copy(copy_operations=True)

if isinstance(cost.device, qml.devices.Device):
program, _ = cost.device.preprocess()

loss_curr, loss_next = qml.execute(
[tape_loss_curr, tape_loss_next], cost.device, None, transform_program=program
)
program, _ = cost.device.preprocess()

else:
loss_curr, loss_next = qml.execute([tape_loss_curr, tape_loss_next], cost.device, None)
loss_curr, loss_next = qml.execute(
[tape_loss_curr, tape_loss_next], cost.device, None, transform_program=program
)

# self.k has been updated earlier
ind = (self.k - 2) % self.last_n_steps.size
Expand Down
Loading

0 comments on commit 1f55c88

Please sign in to comment.