Add batch_partial implementation #2585

Merged 29 commits on May 31, 2022

Changes from 24 commits

Commits
c65d301
Add batch_partial implementation
May 13, 2022
8c80a51
fix codefactor errors
May 17, 2022
454c8e8
Add tests for all interfaces
May 17, 2022
d953e7c
Fix tf test
May 18, 2022
7117fcc
Add docs and a couple more tests
May 18, 2022
90303a6
Improve docs and fix broken tests
May 18, 2022
abb5dab
Fix indent error
May 18, 2022
515b13d
Merge branch 'master' into batch_partial_qnode
eddddddy May 19, 2022
b669d45
Add changelog entry
May 19, 2022
24fab51
Update pennylane/transforms/batch_partial.py
eddddddy May 20, 2022
26f2079
Update pennylane/transforms/batch_partial.py
eddddddy May 20, 2022
fd76986
Update pennylane/transforms/batch_partial.py
eddddddy May 20, 2022
b826dc0
Update doc/releases/changelog-dev.md
eddddddy May 20, 2022
7aaf29e
Change example in docstring
May 20, 2022
b5cfe4b
Change test to match new error message
May 20, 2022
797160e
Update tests/transforms/test_batch_partial.py
eddddddy May 20, 2022
365a63c
Update tests/transforms/test_batch_partial.py
eddddddy May 20, 2022
5434172
Update tests/transforms/test_batch_partial.py
eddddddy May 20, 2022
037b673
Update tests/transforms/test_batch_partial.py
eddddddy May 20, 2022
ed1bd58
Update tests/transforms/test_batch_partial.py
eddddddy May 20, 2022
111b22e
Update pennylane/transforms/batch_partial.py
eddddddy May 20, 2022
5d2b11d
Update pennylane/transforms/batch_partial.py
eddddddy May 22, 2022
04a99ed
Changes for review
May 24, 2022
cce077b
Add test for coverage
May 24, 2022
f127987
Apply suggestions from code review
eddddddy May 26, 2022
ecd0ee4
Merge branch 'master' into batch_partial_qnode
eddddddy May 26, 2022
92df7be
rerun ci
May 26, 2022
13e0fee
Merge branch 'master' into batch_partial_qnode
antalszava May 31, 2022
37c2b5b
Remove unused import
May 31, 2022
25 changes: 23 additions & 2 deletions doc/releases/changelog-dev.md
@@ -7,7 +7,7 @@
* Boolean mask indexing of the parameter-shift Hessian
[(#2538)](https://github.com/PennyLaneAI/pennylane/pull/2538)

The `argnum` keyword argument for `param_shift_hessian`
is now allowed to be a two-dimensional Boolean `array_like`.
Only the indicated entries of the Hessian will then be computed.
A particularly useful example is the computation of the diagonal
@@ -37,6 +37,27 @@
The code that checks for qubit-wise commuting (QWC) got a performance boost that is noticeable
when many commuting Paulis of the same type are measured.

* Added a new transform `qml.batch_partial`, which behaves similarly to `functools.partial` but supports batching in the unevaluated parameters.
[(#2585)](https://github.com/PennyLaneAI/pennylane/pull/2585)

This is useful for executing a circuit with a batch dimension in some of its parameters:

```python
dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(x, y):
qml.RX(x, wires=0)
qml.RY(y, wires=0)
return qml.expval(qml.PauliZ(wires=0))
```
```pycon
>>> batched_partial_circuit = qml.batch_partial(circuit, x=np.array(np.pi / 4))
>>> y = np.array([0.2, 0.3, 0.4])
>>> batched_partial_circuit(y=y)
tensor([0.69301172, 0.67552491, 0.65128847], requires_grad=True)
```

<h3>Improvements</h3>

* The developer-facing `pow` method has been added to `Operator` with concrete implementations
@@ -111,7 +132,7 @@

<h3>Bug fixes</h3>

* `QNode`s can now interpret variations on the interface name, like `"tensorflow"` or `"jax-jit"`, when requesting backpropagation.
[(#2591)](https://github.com/PennyLaneAI/pennylane/pull/2591)

* Fixed a bug for `diff_method="adjoint"` where incorrect gradients were
1 change: 1 addition & 0 deletions pennylane/__init__.py
@@ -60,6 +60,7 @@
batch_params,
batch_input,
batch_transform,
batch_partial,
cut_circuit,
cut_circuit_mc,
ControlledOperation,
2 changes: 2 additions & 0 deletions pennylane/transforms/__init__.py
@@ -32,6 +32,7 @@
~transforms.classical_jacobian
~batch_params
~batch_input
~batch_partial
~metric_tensor
~adjoint_metric_tensor
~specs
@@ -179,6 +180,7 @@
from .adjoint import adjoint
from .batch_params import batch_params
from .batch_input import batch_input
from .batch_partial import batch_partial
from .classical_jacobian import classical_jacobian
from .condition import cond, Conditional
from .compile import compile
193 changes: 193 changes & 0 deletions pennylane/transforms/batch_partial.py
@@ -0,0 +1,193 @@
# Copyright 2022 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the batch dimension transform for partial use of QNodes.
"""
import functools
import inspect

import pennylane as qml


def _convert_to_args(sig, args, kwargs):
"""
Given the signature of a function, convert the positional and
keyword arguments to purely positional arguments.
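
For example, for a QNode signature ``(x, y)`` with ``args = (1.0,)`` and
``kwargs = {"y": 2.0}``, this returns ``(1.0, 2.0)``.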
"""
new_args = []
for i, param in enumerate(sig):
if param in kwargs:
# first check if the name is provided in the keyword arguments
new_args.append(kwargs[param])
else:
# if not, then the argument must be positional
new_args.append(args[i])

return tuple(new_args)


def batch_partial(qnode, all_operations=False, preprocess=None, **partial_kwargs):
"""
Create a batched partial callable object from the QNode specified.

This transform provides functionality akin to ``functools.partial`` and
allows batching the arguments used for calling the batched partial object.

Args:
qnode (pennylane.QNode): QNode to pre-supply arguments to
all_operations (bool): If ``True``, a batch dimension will be added to *all* operations
in the QNode, rather than just trainable QNode parameters.
preprocess (dict): If provided, maps QNode argument names to callables. When the
returned function is called, each callable is evaluated elementwise over the
batched arguments and the stacked result is passed to the QNode in place of
that argument. All remaining QNode arguments must then be pre-supplied via
``partial_kwargs``.
partial_kwargs (dict): pre-supplied arguments to pass to the QNode.

Returns:
func: Function that wraps the QNode and accepts the same arguments minus the
pre-supplied ones. It behaves like the QNode called with both the pre-supplied
arguments and the arguments passed to the wrapper, except that the first
dimension of each wrapper argument is treated as a batch dimension. The output
of the wrapper also gains a leading batch dimension.

**Example**

Consider the following circuit:

.. code-block:: python

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(x, y):
qml.RX(x, wires=0)
qml.RY(y, wires=1)
return qml.expval(qml.PauliZ(wires=0) @ qml.PauliZ(wires=1))

The ``qml.batch_partial`` decorator allows us to create a partially evaluated
function that wraps the QNode. For example,

>>> y = np.array(0.2)
>>> batched_partial_circuit = qml.batch_partial(circuit, y=y)

The unevaluated arguments of the resulting function must now have a batch
dimension, and the output of the function also has a batch dimension:

>>> batch_size = 4
>>> x = np.linspace(0.1, 0.5, batch_size)
>>> batched_partial_circuit(x)
tensor([0.97517033, 0.95350781, 0.91491915, 0.86008934], requires_grad=True)

Jacobians can be computed for the arguments of the wrapper function, but
not for any partially evaluated arguments passed to ``qml.batch_partial``:

>>> qml.jacobian(batched_partial_circuit)(x)
array([[-0.0978434 , 0. , 0. , 0. ],
[ 0. , -0.22661276, 0. , 0. ],
[ 0. , 0. , -0.35135943, 0. ],
[ 0. , 0. , 0. , -0.46986895]])

``qml.batch_partial`` can also be used to replace arguments of a QNode with
functions, by passing them via the ``preprocess`` keyword argument. Calling the
wrapper then evaluates those functions on the batched inputs and passes the
results into the QNode. For example,

>>> x = np.array(0.1)
>>> y_fn = lambda y0: y0 * 0.2 + 0.3
>>> batched_lambda_circuit = qml.batch_partial(circuit, x=x, preprocess={"y": y_fn})

The wrapped function ``batched_lambda_circuit`` also expects arguments to
have an initial batch dimension:

>>> batch_size = 4
>>> y0 = np.linspace(0.5, 2, batch_size)
>>> batched_lambda_circuit(y0)
tensor([0.91645953, 0.8731983 , 0.82121237, 0.76102116], requires_grad=True)

Jacobians can be computed in this scenario as well:

>>> qml.jacobian(batched_lambda_circuit)(y0)
array([[-0.07749457, 0. , 0. , 0. ],
[ 0. , -0.09540608, 0. , 0. ],
[ 0. , 0. , -0.11236432, 0. ],
[ 0. , 0. , 0. , -0.12819986]])
"""
qnode = qml.batch_params(qnode, all_operations=all_operations)

preprocess = {} if preprocess is None else preprocess

# store whether this decorator is being used as a pure
# analog of functools.partial, or whether it is used to
# wrap a QNode in a more complex lambda statement
is_partial = preprocess == {}

# determine which arguments need to be stacked along the batch dimension
to_stack = []
for key, val in partial_kwargs.items():
try:
# check if the value is a tensor
if qml.math.asarray(val).dtype != object:
to_stack.append(key)
except ImportError:
# autoray can't find a backend for val, so it cannot be stacked
pass

sig = inspect.signature(qnode).parameters
if is_partial:
# the partially evaluated function must have at least one more
# parameter, otherwise batching doesn't make sense
if len(sig) <= len(partial_kwargs):
raise ValueError("Partial evaluation must leave at least one unevaluated parameter")
else:
# if used to wrap a QNode in a lambda statement, then check that
# all arguments are provided
if len(sig) > len(partial_kwargs) + len(preprocess):
raise ValueError("Callable argument requires all other arguments to QNode be provided")

@functools.wraps(qnode)
def wrapper(*args, **kwargs):

# raise an error if keyword arguments are passed, since the
# arguments are passed to the lambda statement instead of the QNode
if not is_partial and kwargs:
raise ValueError(
"Arguments must not be passed as keyword arguments to "
"callable within partial function"
)

# get the batch dimension (we don't have to check if all arguments
# have the same batch dim since that's done in qml.batch_params)
try:
if args:
batch_dim = qml.math.shape(args[0])[0]
else:
batch_dim = qml.math.shape(list(kwargs.values())[0])[0]
except IndexError:
raise ValueError("Parameter with batch dimension must be provided") from None

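# evaluate each preprocess callable elementwise across the batch: unstack the
# wrapper's positional arguments, apply the callable to each slice, and stack
# the results back into a single batched value for the QNode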
for key, val in preprocess.items():
unstacked_args = (qml.math.unstack(arg) for arg in args)
val = qml.math.stack([val(*a) for a in zip(*unstacked_args)])
kwargs[key] = val

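# tensor-like pre-supplied values are tiled along the batch dimension so that
# qml.batch_params receives a consistent batch; other values are passed through unchanged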
for key, val in partial_kwargs.items():
if key in to_stack:
kwargs[key] = qml.math.stack([val] * batch_dim)
else:
kwargs[key] = val

if is_partial:
return qnode(*_convert_to_args(sig, args, kwargs))

# don't pass the arguments to the lambda itself into the QNode
return qnode(*_convert_to_args(sig, (), kwargs))

return wrapper