Replacing map_batch_transform in source code #5212

Merged
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
@@ -373,6 +373,10 @@

 <h3>Deprecations 👋</h3>

+* `map_batch_transform` has been replaced in the source code with the method `_batch_transform`
+  implemented in `TransformDispatcher`.
+  [(#5212)](https://github.com/PennyLaneAI/pennylane/pull/5212)
+
 * `TransformDispatcher` can now dispatch onto a batch of tapes, so that it is easier to compose transforms
   when working in the tape paradigm.
   [(#5163)](https://github.com/PennyLaneAI/pennylane/pull/5163)
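For context, a minimal sketch of the pattern change described in the changelog entry above (assuming a batch containing one parameter-broadcasted tape; `broadcast_expand` is the same transform touched in `_device.py` below):

```python
import pennylane as qml

# a batch containing one parameter-broadcasted tape
tape = qml.tape.QuantumScript([qml.RX([0.1, 0.2], wires=0)], [qml.expval(qml.PauliZ(0))])
batch = [tape]

# old pattern: explicitly map the transform over the batch
# tapes, fn = qml.transforms.map_batch_transform(qml.transforms.broadcast_expand, batch)

# new pattern: the TransformDispatcher dispatches directly onto the batch
tapes, fn = qml.transforms.broadcast_expand(batch)
```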
4 changes: 1 addition & 3 deletions pennylane/_device.py
@@ -803,9 +803,7 @@ def hamiltonian_fn(res):
             return circuits, hamiltonian_fn

         # Expand each of the broadcasted Hamiltonian-expanded circuits
-        expanded_tapes, expanded_fn = qml.transforms.map_batch_transform(
-            qml.transforms.broadcast_expand, circuits
-        )
+        expanded_tapes, expanded_fn = qml.transforms.broadcast_expand(circuits)

         # Chain the postprocessing functions of the broadcasted-tape expansions and the Hamiltonian
         # expansion. Note that the application order is reversed compared to the expansion order,
5 changes: 5 additions & 0 deletions pennylane/transforms/mitigate.py
@@ -525,6 +525,11 @@ def circuit(w1, w2):

     def processing_fn(results):
         """Maps from input tape executions to an error-mitigated estimate"""
+
+        # content of `results` must be modified in this post-processing function
+        if isinstance(results, tuple):
+            results = list(results)
+
         for i, tape in enumerate(out_tapes):
             # stack the results if there are multiple measurements
             # this will not create ragged arrays since only expval measurements are allowed
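A note on the tuple-to-list conversion above: the post-processing overwrites entries of `results` in place, and tuples are immutable. A minimal standalone sketch of the idea (the stacking step is a placeholder, not the actual ZNE logic):

```python
import numpy as np

def processing_fn(results):
    """Sketch: overwrite each entry of a batch of results in place."""
    # executors may return the batch as a tuple; tuples are immutable,
    # so convert to a list before assigning to individual entries
    if isinstance(results, tuple):
        results = list(results)
    for i, res in enumerate(results):
        results[i] = np.stack([res])  # placeholder for the real per-tape step
    return results

processing_fn((np.array(0.1), np.array(0.2)))
```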
6 changes: 4 additions & 2 deletions pennylane/workflow/execution.py
@@ -196,8 +196,10 @@ def _batch_transform(
     """
     # TODO: Remove once old device are removed
     if device_batch_transform:
-        dev_batch_transform = set_shots(device, override_shots)(device.batch_transform)
-        return *qml.transforms.map_batch_transform(dev_batch_transform, tapes), config
+        dev_batch_transform = qml.transform(
+            set_shots(device, override_shots)(device.batch_transform)
+        )
+        return *dev_batch_transform(tapes), config

     def null_post_processing_fn(results):
         """A null post processing function used because the user requested not to use the device batch transform."""
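Wrapping in `qml.transform` is what makes the old-style `device.batch_transform` dispatchable like any other transform. A minimal sketch with a toy transform (hypothetical names, following the annotated signature that `qml.transform` expects):

```python
from typing import Callable, Sequence

import pennylane as qml
from pennylane.tape import QuantumTape

def identity_transform(tape: QuantumTape) -> (Sequence[QuantumTape], Callable):
    """Toy transform: return the tape unchanged with trivial postprocessing."""
    return [tape], lambda results: results[0]

# qml.transform produces a TransformDispatcher, which can be called on a
# single tape or, as in this PR, on a whole batch of tapes
dispatched = qml.transform(identity_transform)

tape = qml.tape.QuantumScript([qml.Hadamard(0)], [qml.expval(qml.PauliZ(0))])
new_batch, postprocessing = dispatched([tape, tape])
```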
11 changes: 2 additions & 9 deletions pennylane/workflow/jacobian_products.py
@@ -15,7 +15,6 @@
 Defines classes that take the vjps, jvps, and jacobians of circuits.
 """
 import abc
-from functools import partial
 import inspect
 import logging
 from typing import Tuple, Callable, Optional, Union
@@ -310,10 +309,7 @@ def execute_and_compute_jacobian(self, tapes: Batch):

         num_result_tapes = len(tapes)

-        partial_gradient_fn = partial(self._gradient_transform, **self._gradient_kwargs)
-        jac_tapes, jac_postprocessing = qml.transforms.map_batch_transform(
-            partial_gradient_fn, tapes
-        )
+        jac_tapes, jac_postprocessing = self._gradient_transform(tapes, **self._gradient_kwargs)

         full_batch = tapes + tuple(jac_tapes)
         full_results = self._inner_execute(full_batch)
@@ -327,10 +323,7 @@ def compute_jacobian(self, tapes: Batch):
         logger.debug("compute_jacobian called with %s", tapes)
         if tapes in self._cache:
             return self._cache[tapes]
-        partial_gradient_fn = partial(self._gradient_transform, **self._gradient_kwargs)
-        jac_tapes, batch_post_processing = qml.transforms.map_batch_transform(
-            partial_gradient_fn, tapes
-        )
+        jac_tapes, batch_post_processing = self._gradient_transform(tapes, **self._gradient_kwargs)
         results = self._inner_execute(jac_tapes)
         jacs = tuple(batch_post_processing(results))
         self._cache[tapes] = jacs
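Since a `TransformDispatcher` forwards keyword arguments itself, the `functools.partial` wrapper is no longer needed. A minimal sketch of the new call pattern, using `param_shift` as a stand-in for `self._gradient_transform`:

```python
import pennylane as qml
from pennylane.gradients import param_shift

tape = qml.tape.QuantumScript([qml.RX(0.5, wires=0)], [qml.expval(qml.PauliZ(0))])

# old pattern:
#   partial_fn = partial(param_shift, **gradient_kwargs)
#   jac_tapes, fn = qml.transforms.map_batch_transform(partial_fn, (tape,))

# new pattern: dispatch on the batch and pass the kwargs in the same call
jac_tapes, postprocessing = param_shift((tape,), broadcast=False)
```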
2 changes: 1 addition & 1 deletion tests/interfaces/test_autograd.py
@@ -219,7 +219,7 @@ def test_batch_transform_dynamic_shots(self):
         H = 2.0 * qml.PauliZ(0)
         qscript = qml.tape.QuantumScript(measurements=[qml.expval(H)])
         res = qml.execute([qscript], dev, interface=None, override_shots=10)
-        assert res == [2.0]
+        assert res == (2.0,)


 class TestCaching:
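The updated assertion reflects that `qml.execute` returns the batch of results as a tuple rather than a list. A quick sketch (assuming `default.qubit`):

```python
import pennylane as qml

dev = qml.device("default.qubit", wires=1)
tape = qml.tape.QuantumScript([qml.Hadamard(0)], [qml.expval(qml.PauliZ(0))])

res = qml.execute([tape], dev)
assert isinstance(res, tuple)  # the batch of results comes back as a tuple
```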
5 changes: 3 additions & 2 deletions tests/interfaces/test_jax.py
@@ -18,6 +18,7 @@

 import pennylane as qml
 from pennylane.gradients import param_shift
+from pennylane.typing import TensorLike
 from pennylane import execute

 pytestmark = pytest.mark.jax
@@ -576,7 +577,7 @@ def cost_fn(x):
             return execute(tapes=[tape1, tape2], device=dev, **execute_kwargs)

         res = cost_fn(params)
-        assert isinstance(res, list)
+        assert isinstance(res, TensorLike)
         assert all(isinstance(r, jax.numpy.ndarray) for r in res)
         assert all(r.shape == () for r in res)
@@ -856,7 +857,7 @@ def cost(x, y, device, interface, ek):
             x, y, dev, interface="jax-python", ek=execute_kwargs
         )

-        assert isinstance(res, list)
+        assert isinstance(res, TensorLike)
         assert len(res) == 2

         for r, exp_shape in zip(res, [(), (2,)]):
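`TensorLike` is used in these assertions because its instance check accepts plain Python sequences as well as framework tensors, so it covers both the old list results and the new tuple results. For instance:

```python
from pennylane.typing import TensorLike

assert isinstance((0.1, 0.2), TensorLike)  # tuples count as TensorLike
assert isinstance([0.1, 0.2], TensorLike)  # and so do lists
```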
3 changes: 2 additions & 1 deletion tests/interfaces/test_jax_jit.py
@@ -19,6 +19,7 @@

 import pennylane as qml
 from pennylane.gradients import param_shift
+from pennylane.typing import TensorLike
 from pennylane import execute

 pytestmark = pytest.mark.jax
@@ -580,7 +581,7 @@ def cost_fn(x):
             return execute(tapes=[tape1, tape2], device=dev, **execute_kwargs)

         res = jax.jit(cost_fn)(params)
-        assert isinstance(res, list)
+        assert isinstance(res, TensorLike)
         assert all(isinstance(r, jax.numpy.ndarray) for r in res)
         assert all(r.shape == () for r in res)
3 changes: 2 additions & 1 deletion tests/interfaces/test_torch.py
@@ -18,6 +18,7 @@

 import pennylane as qml
 from pennylane.gradients import finite_diff, param_shift
+from pennylane.typing import TensorLike
 from pennylane import execute

 pytestmark = pytest.mark.torch
@@ -421,7 +422,7 @@ def test_execution(self, torch_device, execute_kwargs):

         res = execute([tape1, tape2], dev, **execute_kwargs)

-        assert isinstance(res, list)
+        assert isinstance(res, TensorLike)
         assert len(res) == 2
         assert res[0].shape == ()
         assert res[1].shape == ()