Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor cache_execute() as a transform #5318

Merged
merged 28 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
41bab70
Add `_cache_transform()` function to replace `cache_execute()`
Mandrenkov Mar 5, 2024
dc13c4c
Update docstring to mention `_cache_transform()`
Mandrenkov Mar 5, 2024
1b0c129
Update tests to spy on `_cache_transform()`
Mandrenkov Mar 5, 2024
e8463fa
Add explicit tests for `_cache_transform()`
Mandrenkov Mar 5, 2024
eb3f1dc
Extract `_apply_cache_transform()` for better reuse
Mandrenkov Mar 5, 2024
dbbd873
Add note to changelog under 'Improvements' section
Mandrenkov Mar 5, 2024
f116252
Disable pylint fixture warnings
Mandrenkov Mar 5, 2024
d494609
Merge branch 'master' into sc-39484-cache-transform
Mandrenkov Mar 5, 2024
37f9c44
Merge branch 'master' into sc-39484-cache-transform
Mandrenkov Mar 8, 2024
94a8e74
Restore tape with finite shots and persistent cache warning
Mandrenkov Mar 8, 2024
34b9ea9
Implement `_make_inner_execute()` using TransformProgram
Mandrenkov Mar 8, 2024
6191f04
Merge branch 'master' into sc-39484-cache-transform
Mandrenkov Mar 8, 2024
664f3ee
Add check for `None` cache
Mandrenkov Mar 11, 2024
d7b5dba
Update remaining broken `mocker.spy` references
Mandrenkov Mar 12, 2024
531beff
Avoid passing empty tape sequence to device execution
Mandrenkov Mar 12, 2024
07e57ee
Avoid comparing length of sequence to 0
Mandrenkov Mar 12, 2024
deefa70
Merge branch 'master' into sc-39484-cache-transform
Mandrenkov Mar 12, 2024
521cb00
Update number of expected logs in tests
Mandrenkov Mar 12, 2024
234d564
Delete `cache_execute()` and any mentions of it
Mandrenkov Mar 12, 2024
89da01d
Delete unused `wraps` import
Mandrenkov Mar 12, 2024
b43eb18
Add unit tests for `_apply_cache_transform()`
Mandrenkov Mar 13, 2024
de4fbe1
Fix reST formatting of test module docstring
Mandrenkov Mar 13, 2024
14eaaf0
Merge branch 'master' into sc-39484-cache-transform
Mandrenkov Mar 13, 2024
3844a61
Avoid referencing `_cache_transform()` in documentation
Mandrenkov Mar 14, 2024
20b7645
Merge branch 'master' into sc-39484-cache-transform
Mandrenkov Mar 14, 2024
7044857
Merge branch 'master' into sc-39484-cache-transform
Mandrenkov Mar 22, 2024
17fe338
Merge branch 'master' into sc-39484-cache-transform
Mandrenkov Mar 25, 2024
8bd60d5
Merge branch 'master' into sc-39484-cache-transform
Mandrenkov Mar 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,17 @@

* The `molecular_hamiltonian` function calls `PySCF` directly when `method='pyscf'` is selected.
[(#5118)](https://github.com/PennyLaneAI/pennylane/pull/5118)
* All generators in the source code (except those in the `qchem` module) no longer return

* All generators in the source code (except those in the `qchem` module) no longer return
`Hamiltonian` or `Tensor` instances. Wherever possible, these return `Sum`, `SProd`, and `Prod` instances.
[(#5253)](https://github.com/PennyLaneAI/pennylane/pull/5253)

* Upgraded `null.qubit` to the new device API. Also, added support for all measurements and various modes of differentiation.
[(#5211)](https://github.com/PennyLaneAI/pennylane/pull/5211)

* Replaced `cache_execute` with an alternate implementation based on `@transform`.
[(#5318)](https://github.com/PennyLaneAI/pennylane/pull/5318)

* The `QNode` now defers `diff_method` validation to the device under the new device api `qml.devices.Device`.
[(#5176)](https://github.com/PennyLaneAI/pennylane/pull/5176)

Expand Down
120 changes: 92 additions & 28 deletions pennylane/workflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
import inspect
import warnings
from functools import wraps, partial
from typing import Callable, Sequence, Optional, Union, Tuple
from typing import Callable, MutableMapping, Sequence, Optional, Union, Tuple
import logging

from cachetools import LRUCache, Cache

import pennylane as qml
from pennylane.tape import QuantumTape
from pennylane.transforms import transform
from pennylane.typing import ResultBatch

from .set_shots import set_shots
Expand Down Expand Up @@ -85,6 +86,17 @@
"""list[str]: allowed interface strings"""


_CACHED_EXECUTION_WITH_FINITE_SHOTS_WARNINGS = (
"Cached execution with finite shots detected!\n"
"Note that samples as well as all noisy quantities computed via sampling "
"will be identical across executions. This situation arises where tapes "
"are executed with identical operations, measurements, and parameters.\n"
"To avoid this behaviour, provide 'cache=False' to the QNode or execution "
"function."
)
"""str: warning message to display when cached execution is used with finite shots"""
trbromley marked this conversation as resolved.
Show resolved Hide resolved


def _adjoint_jacobian_expansion(
tapes: Sequence[QuantumTape], grad_on_execution: bool, interface: str, max_expansion: int
):
Expand Down Expand Up @@ -259,28 +271,31 @@ def _make_inner_execute(

if isinstance(device, qml.Device):
device_execution = set_shots(device, override_shots)(device.batch_execute)

else:
device_execution = partial(device.execute, execution_config=execution_config)

cached_device_execution = qml.workflow.cache_execute(
device_execution, cache, return_tuple=False
)

def inner_execute(tapes: Sequence[QuantumTape], **_) -> ResultBatch:
"""Execution that occurs within a machine learning framework boundary.

Closure Variables:
expand_fn (Callable[[QuantumTape], QuantumTape]): A device preprocessing step
numpy_only (bool): whether or not to convert the data to numpy or leave as is
cached_device_execution (Callable[[Sequence[QuantumTape]], ResultBatch])

device_execution (Callable[[Sequence[QuantumTape]], ResultBatch])
cache (None | MutableMapping): The cache to use. If ``None``, caching will not occur.
"""
transform_program = qml.transforms.core.TransformProgram()

if cache is not None:
transform_program.add_transform(_cache_transform, cache=cache)

# TODO: Apply expand_fn() and convert_to_numpy_parameters() as transforms.
if expand_fn:
tapes = tuple(expand_fn(t) for t in tapes)
if numpy_only:
tapes = tuple(qml.transforms.convert_to_numpy_parameters(t) for t in tapes)
return cached_device_execution(tapes)

transformed_tapes, transform_post_processing = transform_program(tapes)
return transform_post_processing(device_execution(transformed_tapes))
Mandrenkov marked this conversation as resolved.
Show resolved Hide resolved

return inner_execute

Expand Down Expand Up @@ -322,6 +337,8 @@ def cache_execute(fn: Callable, cache, pass_kwargs=False, return_tuple=True, exp
function: a wrapped version of the execution function ``fn`` with caching
support
"""
# TODO: Add deprecation warning.
# This function has been replaced by ``_cache_transform()``.
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Entry with args=(fn=%s, cache=%s, pass_kwargs=%s, return_tuple=%s, expand_fn=%s) called by=%s",
Expand Down Expand Up @@ -383,15 +400,7 @@ def wrapper(tapes: Sequence[QuantumTape], **kwargs):
# Tape exists within the cache, store the cached result
cached_results[i] = cache[hashes[i]]
if tape.shots and getattr(cache, "_persistent_cache", True):
warnings.warn(
"Cached execution with finite shots detected!\n"
"Note that samples as well as all noisy quantities computed via sampling "
"will be identical across executions. This situation arises where tapes "
"are executed with identical operations, measurements, and parameters.\n"
"To avoid this behavior, provide 'cache=False' to the QNode or execution "
"function.",
UserWarning,
)
warnings.warn(_CACHED_EXECUTION_WITH_FINITE_SHOTS_WARNINGS, UserWarning)
else:
# Tape does not exist within the cache, store the tape
# for execution via the execution function.
Expand Down Expand Up @@ -431,6 +440,62 @@ def wrapper(tapes: Sequence[QuantumTape], **kwargs):
return wrapper


@transform
def _cache_transform(tape: QuantumTape, cache: MutableMapping):
    """Caches the result of ``tape`` using the provided ``cache``.

    On a cache hit, an empty tape batch is returned and the post-processing
    function reads the result back out of ``cache``. On a cache miss, the tape
    is forwarded unchanged and the post-processing function stores its result
    in ``cache`` for future lookups.

    Args:
        tape (QuantumTape): the tape whose result should be cached.
        cache (MutableMapping): mapping from tape hashes to results. A ``None``
            value marks a tape that has been scheduled for execution but whose
            result has not been stored yet.

    Returns:
        tuple: the batch of tapes that still require execution (empty on a
        cache hit) together with the matching post-processing function.

    .. note::

        This function makes use of :attr:`.QuantumTape.hash` to identify unique tapes.
    """

    def cache_hit_postprocessing(_results: Tuple[Tuple]) -> Tuple:
        # ``_results`` is ignored: no tapes were forwarded for execution, so the
        # result must come from the cache instead.
        result = cache[tape.hash]
        if result is not None:
            # Reusing a cached result with finite shots means the "sampled"
            # values repeat across executions, so warn the user. The warning is
            # skipped when the cache declares itself non-persistent via its
            # ``_persistent_cache`` attribute (defaults to persistent).
            if tape.shots and getattr(cache, "_persistent_cache", True):
                warnings.warn(_CACHED_EXECUTION_WITH_FINITE_SHOTS_WARNINGS, UserWarning)
            return result

        # A ``None`` entry means the tape was scheduled but its result was never
        # stored; the post-processing functions must have run out of order.
        raise RuntimeError(
            "Result for tape is missing from the execution cache. "
            "This is likely the result of a race condition."
        )

    if tape.hash in cache:
        # Cache hit: nothing to execute for this tape.
        return [], cache_hit_postprocessing

    def cache_miss_postprocessing(results: Tuple[Tuple]) -> Tuple:
        # Exactly one tape was forwarded, so its result is the sole entry of
        # ``results``; store it in the cache before handing it back.
        result = results[0]
        cache[tape.hash] = result
        return result

    # Adding a ``None`` entry to the cache indicates that a result will eventually be available for
    # the tape. This assumes that post-processing functions are called in the same order in which
    # the transforms are invoked. Otherwise, ``cache_hit_postprocessing()`` may be called before the
    # result of the corresponding tape is placed in the cache by ``cache_miss_postprocessing()``.
    cache[tape.hash] = None
    return [tape], cache_miss_postprocessing


def _apply_cache_transform(fn: Callable, cache: Optional[MutableMapping]) -> Callable:
"""Wraps the given execution function with ``_cache_transform()`` using the provided cache.

Args:
fn (Callable): The execution function to be augmented with caching. This function should
have the signature ``fn(tapes, **kwargs)`` and return ``list[tensor_like]`` with the
same length as the input ``tapes``.
cache (None | MutableMapping): The cache to use. If ``None``, caching will not occur.
"""
if cache is None:
return fn

def execution_function_with_caching(tapes):
tapes, post_processing_fn = _cache_transform(tapes, cache=cache)
return post_processing_fn(fn(tapes))

return execution_function_with_caching


def execute(
tapes: Sequence[QuantumTape],
device: device_type,
Expand All @@ -440,7 +505,7 @@ def execute(
config=None,
grad_on_execution="best",
gradient_kwargs=None,
cache: Union[bool, dict, Cache] = True,
cache: Union[None, bool, dict, Cache] = True,
cachesize=10000,
max_diff=1,
override_shots: int = False,
Expand Down Expand Up @@ -471,7 +536,7 @@ def execute(
pass. The 'best' option chooses automatically between the two options and is default.
gradient_kwargs (dict): dictionary of keyword arguments to pass when
determining the gradients of tapes
cache (bool, dict, Cache): Whether to cache evaluations. This can result in
cache (None, bool, dict, Cache): Whether to cache evaluations. This can result in
a significant reduction in quantum evaluations during gradient computations.
cachesize (int): the size of the cache
max_diff (int): If ``gradient_fn`` is a gradient transform, this option specifies
Expand Down Expand Up @@ -624,11 +689,15 @@ def cost_fn(params, x):
else:
transform_program = qml.transforms.core.TransformProgram()

if isinstance(cache, bool) and cache:
# cache=True: create a LRUCache object
# If caching is desired but an explicit cache is not provided, use an ``LRUCache``.
if cache is True:
cache = LRUCache(maxsize=cachesize)
setattr(cache, "_persistent_cache", False)

# Ensure that ``cache`` is not a Boolean to simplify downstream code.
elif cache is False:
cache = None

expand_fn = _preprocess_expand_fn(expand_fn, device, max_expansion)

# changing this set of conditions causes a bunch of tests to break.
Expand Down Expand Up @@ -797,12 +866,7 @@ def inner_execute_with_empty_jac(tapes, **_):

# replace the backward gradient computation
gradient_fn_with_shots = set_shots(device, override_shots)(device.gradients)
cached_gradient_fn = qml.workflow.cache_execute(
gradient_fn_with_shots,
cache,
pass_kwargs=True,
return_tuple=False,
)
cached_gradient_fn = _apply_cache_transform(fn=gradient_fn_with_shots, cache=cache)

def device_gradient_fn(inner_tapes, **gradient_kwargs):
numpy_tapes = tuple(
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/jacobian_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class TransformJacobianProducts(JacobianProductCalculator):
instead of treating each call as independent. This keyword argument is used to patch problematic
autograd behavior when caching is turned off. In this case, caching will be based on the identity
of the batch, rather than the potentially expensive :attr:`~.QuantumScript.hash` that is used
by :func:`~.cache_execute`.
by :func:`~._cache_transform`.
Mandrenkov marked this conversation as resolved.
Show resolved Hide resolved

>>> inner_execute = qml.device('default.qubit').execute
>>> gradient_transform = qml.gradients.param_shift
Expand Down
10 changes: 5 additions & 5 deletions tests/interfaces/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class TestCaching:
def test_cache_maxsize(self, mocker):
"""Test the cachesize property of the cache"""
dev = qml.device("default.qubit.legacy", wires=1)
spy = mocker.spy(qml.workflow, "cache_execute")
spy = mocker.spy(qml.workflow.execution._cache_transform, "_transform")

def cost(a, cachesize):
with qml.queuing.AnnotatedQueue() as q:
Expand All @@ -241,7 +241,7 @@ def cost(a, cachesize):

params = np.array([0.1, 0.2])
qml.jacobian(cost)(params, cachesize=2)
cache = spy.call_args[0][1]
cache = spy.call_args.kwargs["cache"]

assert cache.maxsize == 2
assert cache.currsize == 2
Expand All @@ -250,7 +250,7 @@ def cost(a, cachesize):
def test_custom_cache(self, mocker):
"""Test the use of a custom cache object"""
dev = qml.device("default.qubit.legacy", wires=1)
spy = mocker.spy(qml.workflow, "cache_execute")
spy = mocker.spy(qml.workflow.execution._cache_transform, "_transform")

def cost(a, cache):
with qml.queuing.AnnotatedQueue() as q:
Expand All @@ -265,7 +265,7 @@ def cost(a, cache):
params = np.array([0.1, 0.2])
qml.jacobian(cost)(params, cache=custom_cache)

cache = spy.call_args[0][1]
cache = spy.call_args.kwargs["cache"]
assert cache is custom_cache

def test_caching_param_shift(self, tol):
Expand Down Expand Up @@ -398,7 +398,7 @@ def cost(a, cache):
)[0]
)

# no cache_execute caching, but jac for each batch still stored.
# no caching, but jac for each batch still stored.
qml.jacobian(cost)(params, cache=None)
assert dev.num_executions == 2

Expand Down
12 changes: 6 additions & 6 deletions tests/interfaces/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class TestCaching:
def test_cache_maxsize(self, mocker):
"""Test the cachesize property of the cache"""
dev = qml.device("default.qubit.legacy", wires=1)
spy = mocker.spy(qml.workflow, "cache_execute")
spy = mocker.spy(qml.workflow.execution, "_cache_transform")
Mandrenkov marked this conversation as resolved.
Show resolved Hide resolved

def cost(a, cachesize):
with qml.queuing.AnnotatedQueue() as q:
Expand All @@ -213,7 +213,7 @@ def cost(a, cachesize):

params = jax.numpy.array([0.1, 0.2])
jax.grad(cost)(params, cachesize=2)
cache = spy.call_args[0][1]
cache = spy.call_args.kwargs["cache"]

assert cache.maxsize == 2
assert cache.currsize == 2
Expand All @@ -222,7 +222,7 @@ def cost(a, cachesize):
def test_custom_cache(self, mocker):
"""Test the use of a custom cache object"""
dev = qml.device("default.qubit.legacy", wires=1)
spy = mocker.spy(qml.workflow, "cache_execute")
spy = mocker.spy(qml.workflow.execution, "_cache_transform")
Mandrenkov marked this conversation as resolved.
Show resolved Hide resolved

def cost(a, cache):
with qml.queuing.AnnotatedQueue() as q:
Expand All @@ -242,13 +242,13 @@ def cost(a, cache):
params = jax.numpy.array([0.1, 0.2])
jax.grad(cost)(params, cache=custom_cache)

cache = spy.call_args[0][1]
cache = spy.call_args.kwargs["cache"]
assert cache is custom_cache

def test_custom_cache_multiple(self, mocker):
"""Test the use of a custom cache object with multiple tapes"""
dev = qml.device("default.qubit.legacy", wires=1)
spy = mocker.spy(qml.workflow, "cache_execute")
spy = mocker.spy(qml.workflow.execution, "_cache_transform")
Mandrenkov marked this conversation as resolved.
Show resolved Hide resolved

a = jax.numpy.array(0.1)
b = jax.numpy.array(0.2)
Expand Down Expand Up @@ -277,7 +277,7 @@ def cost(a, b, cache):
custom_cache = {}
jax.grad(cost)(a, b, cache=custom_cache)

cache = spy.call_args[0][1]
cache = spy.call_args.kwargs["cache"]
assert cache is custom_cache

def test_caching_param_shift(self, tol):
Expand Down
12 changes: 6 additions & 6 deletions tests/interfaces/test_jax_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class TestCaching:
def test_cache_maxsize(self, mocker):
"""Test the cachesize property of the cache"""
dev = qml.device("default.qubit.legacy", wires=1)
spy = mocker.spy(qml.workflow, "cache_execute")
spy = mocker.spy(qml.workflow.execution, "_cache_transform")
Mandrenkov marked this conversation as resolved.
Show resolved Hide resolved

def cost(a, cachesize):
with qml.queuing.AnnotatedQueue() as q:
Expand All @@ -203,7 +203,7 @@ def cost(a, cachesize):

params = jax.numpy.array([0.1, 0.2])
jax.jit(jax.grad(cost), static_argnums=1)(params, cachesize=2)
cache = spy.call_args[0][1]
cache = spy.call_args.kwargs["cache"]

assert cache.maxsize == 2
assert cache.currsize == 2
Expand All @@ -212,7 +212,7 @@ def cost(a, cachesize):
def test_custom_cache(self, mocker):
"""Test the use of a custom cache object"""
dev = qml.device("default.qubit.legacy", wires=1)
spy = mocker.spy(qml.workflow, "cache_execute")
spy = mocker.spy(qml.workflow.execution, "_cache_transform")
Mandrenkov marked this conversation as resolved.
Show resolved Hide resolved

def cost(a, cache):
with qml.queuing.AnnotatedQueue() as q:
Expand All @@ -233,13 +233,13 @@ def cost(a, cache):
params = jax.numpy.array([0.1, 0.2])
jax.grad(cost)(params, cache=custom_cache)

cache = spy.call_args[0][1]
cache = spy.call_args.kwargs["cache"]
assert cache is custom_cache

def test_custom_cache_multiple(self, mocker):
"""Test the use of a custom cache object with multiple tapes"""
dev = qml.device("default.qubit.legacy", wires=1)
spy = mocker.spy(qml.workflow, "cache_execute")
spy = mocker.spy(qml.workflow.execution, "_cache_transform")
Mandrenkov marked this conversation as resolved.
Show resolved Hide resolved

a = jax.numpy.array(0.1)
b = jax.numpy.array(0.2)
Expand Down Expand Up @@ -270,7 +270,7 @@ def cost(a, b, cache):
custom_cache = {}
jax.grad(cost)(a, b, cache=custom_cache)

cache = spy.call_args[0][1]
cache = spy.call_args.kwargs["cache"]
assert cache is custom_cache

def test_caching_param_shift(self, tol):
Expand Down
Loading
Loading