Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add batch_partial implementation #2585

Merged
merged 29 commits into from
May 31, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c65d301
Add batch_partial implementation
May 13, 2022
8c80a51
fix codefactor errors
May 17, 2022
454c8e8
Add tests for all interfaces
May 17, 2022
d953e7c
Fix tf test
May 18, 2022
7117fcc
Add docs and a couple more tests
May 18, 2022
90303a6
Improve docs and fix broken tests
May 18, 2022
abb5dab
Fix indent error
May 18, 2022
515b13d
Merge branch 'master' into batch_partial_qnode
eddddddy May 19, 2022
b669d45
Add changelog entry
May 19, 2022
24fab51
Update pennylane/transforms/batch_partial.py
eddddddy May 20, 2022
26f2079
Update pennylane/transforms/batch_partial.py
eddddddy May 20, 2022
fd76986
Update pennylane/transforms/batch_partial.py
eddddddy May 20, 2022
b826dc0
Update doc/releases/changelog-dev.md
eddddddy May 20, 2022
7aaf29e
Change example in docstring
May 20, 2022
b5cfe4b
Change test to match new error message
May 20, 2022
797160e
Update tests/transforms/test_batch_partial.py
eddddddy May 20, 2022
365a63c
Update tests/transforms/test_batch_partial.py
eddddddy May 20, 2022
5434172
Update tests/transforms/test_batch_partial.py
eddddddy May 20, 2022
037b673
Update tests/transforms/test_batch_partial.py
eddddddy May 20, 2022
ed1bd58
Update tests/transforms/test_batch_partial.py
eddddddy May 20, 2022
111b22e
Update pennylane/transforms/batch_partial.py
eddddddy May 20, 2022
5d2b11d
Update pennylane/transforms/batch_partial.py
eddddddy May 22, 2022
04a99ed
Changes for review
May 24, 2022
cce077b
Add test for coverage
May 24, 2022
f127987
Apply suggestions from code review
eddddddy May 26, 2022
ecd0ee4
Merge branch 'master' into batch_partial_qnode
eddddddy May 26, 2022
92df7be
rerun ci
May 26, 2022
13e0fee
Merge branch 'master' into batch_partial_qnode
antalszava May 31, 2022
37c2b5b
Remove unused import
May 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
* Boolean mask indexing of the parameter-shift Hessian
[(#2538)](https://github.com/PennyLaneAI/pennylane/pull/2538)

The `argnum` keyword argument for `param_shift_hessian`
The `argnum` keyword argument for `param_shift_hessian`
is now allowed to be a twodimensional Boolean `array_like`.
Only the indicated entries of the Hessian will then be computed.
A particularly useful example is the computation of the diagonal
Expand Down Expand Up @@ -37,6 +37,27 @@
The code that checks for qubit wise commuting (QWC) got a performance boost that is noticable
when many commuting paulis of the same type are measured.

* Added new transform `qml.batch_partial` which behaves similarly to `functools.partial` but supports batching in the unevaluated parameters.
[(#2585)](https://github.com/PennyLaneAI/pennylane/pull/2585)

This is useful for batching circuit executions with some identical parameters but not others:
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

```python
dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(x, y):
qml.RX(x, wires=0)
qml.RY(y, wires=0)
return qml.expval(qml.PauliZ(wires=0))
```
```pycon
>>> batched_partial_circuit = qml.batch_partial(circuit, x=np.array(np.pi / 2))
>>> y = np.array([0.2, 0.3, 0.4])
>>> batched_partial_circuit(y=y)
tensor([0.69301172, 0.67552491, 0.65128847], requires_grad=True)
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
```

<h3>Improvements</h3>

* The developer-facing `pow` method has been added to `Operator` with concrete implementations
Expand Down Expand Up @@ -111,7 +132,7 @@

<h3>Bug fixes</h3>

* `QNode`'s now can interpret variations on the interface name, like `"tensorflow"` or `"jax-jit"`, when requesting backpropagation.
* `QNode`'s now can interpret variations on the interface name, like `"tensorflow"` or `"jax-jit"`, when requesting backpropagation.
[(#2591)](https://github.com/PennyLaneAI/pennylane/pull/2591)

* Fixed a bug for `diff_method="adjoint"` where incorrect gradients were
Expand Down
1 change: 1 addition & 0 deletions pennylane/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
batch_params,
batch_input,
batch_transform,
batch_partial,
cut_circuit,
cut_circuit_mc,
ControlledOperation,
Expand Down
2 changes: 2 additions & 0 deletions pennylane/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
~transforms.classical_jacobian
~batch_params
~batch_input
~batch_partial
~metric_tensor
~adjoint_metric_tensor
~specs
Expand Down Expand Up @@ -179,6 +180,7 @@
from .adjoint import adjoint
from .batch_params import batch_params
from .batch_input import batch_input
from .batch_partial import batch_partial
from .classical_jacobian import classical_jacobian
from .condition import cond, Conditional
from .compile import compile
Expand Down
180 changes: 180 additions & 0 deletions pennylane/transforms/batch_partial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright 2022 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the batch dimension transform.
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
"""
import functools
import inspect

import pennylane as qml


def _convert_to_args(func, args, kwargs):
"""
Given a function, convert the positional and
keyword arguments to purely positional arguments.
"""
sig = inspect.signature(func).parameters
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

new_args = []
for i, param in enumerate(sig):
if param in kwargs:
# first check if the name is provided in kwargs
new_args.append(kwargs[param])
elif i < len(sig):
antalszava marked this conversation as resolved.
Show resolved Hide resolved
# next check if the argnum is provided
new_args.append(args[i])

return tuple(new_args)


def batch_partial(qnode, all_operations=False, **partial_kwargs):
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
"""
Create a wrapper function around the QNode with partially
evaluated parameters, which supports an initial batch dimension.
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

Args:
qnode (pennylane.QNode): QNode to partially evaluate
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
all_operations (bool): If ``True``, a batch dimension will be added to *all* operations
in the QNode, rather than just trainable QNode parameters.
partial_kwargs (dict): partially-evaluated parameters to pass to the QNode
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

Returns:
func: Function which accepts the same arguments as the QNode minus the
partially evaluated arguments provided, and behaves the same as the QNode
called with both the partially evaluated arguments and the extra arguments.
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
However, the first dimension of each argument of the returned function
will be treated as a batch dimension. The function output will also contain
an initial batch dimension.

**Example**

Consider the following circuit:

.. code-block:: python

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(x, y):
qml.RX(x, wires=0)
qml.RY(y[..., 0], wires=0)
qml.RY(y[..., 1], wires=1)
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
return qml.expval(qml.PauliZ(wires=0) @ qml.PauliZ(wires=1))

The ``qml.batch_partial`` decorator allows us to create a partially evaluated
function that wraps the QNode. For example,
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

>>> y = np.array([0.2, 0.3])
>>> batched_partial_circuit = qml.batch_partial(circuit, y=y)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

The unevaluated arguments of the resulting function must now have a batch
dimension, and the output of the function also has a batch dimension:

>>> batch_size = 4
>>> x = np.linspace(0.1, 0.5, batch_size)
>>> batched_partial_circuit(x)
tensor([0.9316158 , 0.91092081, 0.87405565, 0.82167473], requires_grad=True)

Gradients can be computed for the arguments of the wrapper function, but
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
not for any partially evaluated arguments passed to ``qml.batch_partial``:
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

>>> qml.jacobian(batched_partial_circuit)(x)
array([[-0.09347337, 0. , 0. , 0. ],
[ 0. , -0.21649144, 0. , 0. ],
[ 0. , 0. , -0.33566648, 0. ],
[ 0. , 0. , 0. , -0.44888295]])

The same ``qml.batch_partial`` decorator can also be used to replace arguments
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
of a QNode with functions, and calling the wrapper would evaluate
those functions and pass the results into the QNode. For example,
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

>>> x = np.array(0.1)
>>> y_fn = lambda y0: y0 * np.array([0.2, 0.3])
>>> batched_lambda_circuit = qml.batch_partial(circuit, x=x, y=y_fn)

The wrapped function ``batched_lambda_circuit`` also expects arguments to
have an initial batch dimension:

>>> batch_size = 4
>>> y0 = np.linspace(0.5, 2, batch_size)
>>> batched_lambda_circuit(y0)
tensor([0.97891628, 0.9316158 , 0.85593241, 0.75638669], requires_grad=True)

Gradients can be computed in this scenario as well:
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

>>> qml.jacobian(batched_lambda_circuit)(y0)
array([[-0.06402847, 0. , 0. , 0. ],
[ 0. , -0.12422434, 0. , 0. ],
[ 0. , 0. , -0.17699293, 0. ],
[ 0. , 0. , 0. , -0.21920062]])
"""
qnode = qml.batch_params(qnode, all_operations=all_operations)

# store whether this decorator is being used as a pure
# analog of functools.partial, or whether it is used to
# wrap a QNode in a more complex lambda statement
is_partial = False
if not any(callable(val) for val in partial_kwargs.values()):
# none of the kwargs passed in are callable
is_partial = True
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

sig = inspect.signature(qnode).parameters
if is_partial:
# the partially evaluated function must have at least one more
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
# parameter, otherwise batching doesn't make sense
if len(sig) <= len(partial_kwargs):
raise ValueError("Partial evaluation must leave at least one unevaluated parameter")
antalszava marked this conversation as resolved.
Show resolved Hide resolved
else:
# if used to wrap a QNode in a lambda statement, then check that
# all arguments are provided
if len(sig) > len(partial_kwargs):
raise ValueError("Callable argument requires all other arguments to QNode be provided")

@functools.wraps(qnode)
def wrapper(*args, **kwargs):

# raise an error if keyword arguments are passed, since the
# arguments are passed to the lambda statement instead of the QNode
if not is_partial and kwargs:
raise ValueError(
"Arguments must not be passed as keyword arguments to "
"callable within partial function"
)

# get the batch dimension (we don't have to check if all arguments
# have the same batch dim since that's done in qml.batch_params)
try:
if args:
batch_dim = qml.math.shape(args[0])[0]
else:
batch_dim = qml.math.shape(list(kwargs.values())[0])[0]
except IndexError:
raise ValueError("Batch dimension must be provided") from None
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

for key, val in partial_kwargs.items():
if callable(val):
unstacked_args = (qml.math.unstack(arg) for arg in args)
val = qml.math.stack([val(*a) for a in zip(*unstacked_args)])
kwargs[key] = val
else:
kwargs[key] = qml.math.stack([val] * batch_dim)

if is_partial:
return qnode(*_convert_to_args(qnode, args, kwargs))

# don't pass the arguments to the lambda itself into the QNode
return qnode(*_convert_to_args(qnode, (), kwargs))

return wrapper
Loading