Create VJP perf fix #4806

Merged · 6 commits · Nov 8, 2023
2 changes: 1 addition & 1 deletion .github/workflows/interface-unit-tests.yml
@@ -74,7 +74,7 @@ jobs:
"qcut-tests": ["3.9"],
"qchem-tests": ["3.9"],
"gradients-tests": ["3.9"],
"data-tests": ["3.10"],
"data-tests": ["3.9", "3.10"],
"device-tests": ["3.9"]
}
EOF
4 changes: 4 additions & 0 deletions doc/releases/changelog-0.33.1.md
@@ -14,6 +14,9 @@

<h3>Bug fixes 🐛</h3>

+ * Fix gradient performance regression due to expansion of VJP products.
+   [(#4806)](https://github.com/PennyLaneAI/pennylane/pull/4806)

* `qml.defer_measurements` now correctly transforms circuits when terminal measurements include wires
used in mid-circuit measurements.
[(#4787)](https://github.com/PennyLaneAI/pennylane/pull/4787)
@@ -31,4 +34,5 @@
This release contains contributions from (in alphabetical order):

Christina Lee,
+ Lee James O'Riordan,
Mudit Pandey
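
For context on the "VJP perf fix" changelog entry above: a vector-Jacobian product contracts the cotangent vector dy with the tape Jacobian, and the regression came from expanding that product parameter by parameter in Python rather than contracting once. A minimal NumPy sketch of the two forms (illustrative shapes only, not PennyLane's API):

```python
import numpy as np

# Jacobian of 3 measurement entries w.r.t. 4 trainable parameters,
# plus a cotangent vector with one entry per measurement value.
jac = np.random.rand(3, 4)
dy = np.random.rand(3)

# Expanded form: one Python-level pass per parameter (the slow path).
vjp_loop = np.array([np.sum(dy * jac[:, k]) for k in range(4)])

# Single contraction: one vectorized call replaces the loop.
vjp_fast = dy @ jac

assert np.allclose(vjp_loop, vjp_fast)
```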
70 changes: 47 additions & 23 deletions pennylane/gradients/vjp.py
@@ -47,7 +47,7 @@ def _all_close_to_zero(dy):
return qml.math.allclose(dy, 0)

# call this method recursively
- return qml.math.all(qml.math.stack([_all_close_to_zero(dy_) for dy_ in dy]))
+ return all(_all_close_to_zero(dy_) for dy_ in dy)
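
The replacement above swaps qml.math.all over a stacked array for Python's built-in all, which short-circuits on the first non-zero entry and never materializes the stack. A small NumPy-only sketch of the difference:

```python
import numpy as np

def is_zero(entry):
    print("checked one entry")
    return np.allclose(entry, 0)

entries = [np.ones(2), np.zeros(2), np.zeros(2)]

# The generator stops at the first False; no intermediate array is built,
# unlike qml.math.all(qml.math.stack([...])), which evaluates everything.
print(all(is_zero(e) for e in entries))  # "checked one entry" once, then False
```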


def compute_vjp_single(dy, jac, num=None):
@@ -75,7 +75,7 @@ def compute_vjp_single(dy, jac, num=None):
>>> compute_vjp_single(dy, jac)
np.array([0.2])

- 2. For a single parameter and a single measurment with shape (e.g. probs):
+ 2. For a single parameter and a single measurement with shape (e.g. probs):

.. code-block:: pycon

@@ -115,16 +115,8 @@ def compute_vjp_single(dy, jac, num=None):
if not isinstance(dy_row, np.ndarray):
jac = _convert(jac, dy_row)

- try:
-     if _all_close_to_zero(dy):
-         # If the dy vector is zero, then the
-         # corresponding element of the VJP will be zero.
-         num_params = len(jac) if isinstance(jac, tuple) else 1
-
-         res = qml.math.convert_like(np.zeros(num_params), dy)
-         return qml.math.cast_like(res, dy)
- except (AttributeError, TypeError):
-     pass
+ # Note: For generality, all exception type warnings are disabled.
+ # TODO: Explicitly catalogue and update raises for known types.

# Single measurement with a single param
if not isinstance(jac, (tuple, autograd.builtins.SequenceBox)):
@@ -136,7 +128,12 @@
if num == 1:
jac = qml.math.squeeze(jac)
jac = qml.math.reshape(jac, (-1, 1))
- res = qml.math.tensordot(jac, dy_row, [[0], [0]])
+ try:
+     res = dy_row @ jac
+ except Exception:  # pylint: disable=broad-except
+     res = qml.math.tensordot(jac, dy_row, [[0], [0]])

# Single measurement with multiple params
else:
# No trainable parameters (adjoint)
@@ -146,12 +143,19 @@
# Single measurement with no dimension e.g. expval
if num == 1:
jac = qml.math.reshape(qml.math.stack(jac), (1, -1))
- res = qml.math.tensordot(jac, dy_row, [[0], [0]])
+ try:
+     res = dy_row @ jac
+ except Exception:  # pylint: disable=broad-except
+     res = qml.math.tensordot(jac, dy_row, [[0], [0]])

# Single measurement with dimension e.g. probs
else:
jac = qml.math.stack(jac)
- res = qml.math.tensordot(jac, dy_row, [[1], [0]])
+ try:
+     res = jac @ dy_row
+ except Exception:  # pylint: disable=broad-except
+     res = qml.math.tensordot(jac, dy_row, [[1], [0]])

return res
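
The three edits above share one pattern: try the plain @ contraction first and fall back to qml.math.tensordot only when matmul is unsupported by the operand types. A standalone sketch of the pattern, NumPy only (the helper name vjp_contract is ours, not PennyLane's):

```python
import numpy as np

def vjp_contract(dy_row, jac):
    """Contract dy_row against the first axis of jac, preferring `@`."""
    try:
        # Fast path: native matmul avoids dispatch overhead.
        return dy_row @ jac
    except Exception:  # some interfaces/objects do not implement `@`
        # Fallback: explicit contraction over the first axis of both.
        return np.tensordot(jac, dy_row, axes=[[0], [0]])

dy_row = np.array([0.5, -1.0])
jac = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
print(vjp_contract(dy_row, jac))  # shape (3,), one value per parameter
```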


@@ -202,14 +206,34 @@ def compute_vjp_multi(dy, jac, num=None):
res = qml.math.sum(qml.math.stack(res), axis=0)
# Multiple parameters
else:
- res = []
- for d, j_ in zip(dy, jac):
-     sub_res = []
-     for j in j_:
-         sub_res.append(qml.math.squeeze(compute_vjp_single(d, j, num=num)))
-     res.append(sub_res)
- res = qml.math.stack([qml.math.stack(r) for r in res])
- res = qml.math.sum(res, axis=0)
+ try:
+     dy_interface = qml.math.get_interface(dy[0])
+     # dy  -> (i, j)    observables, entries per observable
+     # jac -> (i, k, j) observables, parameters, entries per observable
+     # Contractions over observables and entries per observable
+     dy_shape = qml.math.shape(dy)
+     if len(dy_shape) > 1:  # multiple values exist per observable output
+         return qml.math.array(qml.math.einsum("ij,i...j", dy, jac), like=dy[0])
+
+     if dy_interface == "tensorflow":
+         # TF needs a different path for Hessian support
+         return qml.math.array(
+             qml.math.einsum("i,i...", dy, jac, like=dy[0]), like=dy[0]
+         )  # Scalar value per observable output
+     return qml.math.array(
+         qml.math.einsum("i,i...", dy, jac), like=dy[0]
+     )  # Scalar value per observable output
+ # NOTE: We want any potential failure to fall back here, so catch every exception type
+ # TODO: Catalogue and update for expected exception types
+ except Exception:  # pylint: disable=broad-except
+     res = []
+     for d, j_ in zip(dy, jac):
+         sub_res = []
+         for j in j_:
+             sub_res.append(qml.math.squeeze(compute_vjp_single(d, j, num=num)))
+         res.append(sub_res)
+     res = qml.math.stack([qml.math.stack(r) for r in res])
+     res = qml.math.sum(res, axis=0)
return res
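
The einsum fast path above collapses a Python loop over (observable, parameter) pairs into a single contraction. A sketch checking both einsum signatures against their loop equivalents, assuming plain NumPy arrays:

```python
import numpy as np

n_obs, n_params, n_entries = 2, 3, 4

# Scalar-valued measurements: one dy entry per observable.
dy = np.random.rand(n_obs)
jac = np.random.rand(n_obs, n_params)
fast = np.einsum("i,i...", dy, jac)  # -> (n_params,)
loop = sum(dy[i] * jac[i] for i in range(n_obs))
assert np.allclose(fast, loop)

# Array-valued measurements (e.g. probs): n_entries values per observable.
dy = np.random.rand(n_obs, n_entries)
jac = np.random.rand(n_obs, n_params, n_entries)
fast = np.einsum("ij,i...j", dy, jac)  # -> (n_params,)
loop = np.stack(
    [sum(dy[i] @ jac[i, k] for i in range(n_obs)) for k in range(n_params)]
)
assert np.allclose(fast, loop)
```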


7 changes: 7 additions & 0 deletions pennylane/math/multi_dispatch.py
@@ -540,6 +540,13 @@ def einsum(indices, *operands, like=None, optimize=None):
if optimize is None or like == "torch":
# torch einsum doesn't support the optimize keyword argument
return np.einsum(indices, *operands, like=like)
if like == "tensorflow":
# Unpacking and casting necessary for higher order derivatives,
# and avoiding implicit fp32 down-conversions.
op1, op2 = operands
op1 = array(op1, like=op1[0], dtype=op1[0].dtype)
op2 = array(op2, like=op2[0], dtype=op2[0].dtype)
return np.einsum(indices, op1, op2, like=like)
return np.einsum(indices, *operands, like=like, optimize=optimize)
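
The tensorflow branch above re-packs each operand with its elements' dtype so float64 inputs are not silently demoted to float32 during the contraction. A hedged illustration of the dtype hazard, assuming TensorFlow 2.x's strict-dtype behavior (tf.einsum raises on mixed precisions rather than promoting):

```python
import tensorflow as tf

dy64 = tf.constant([0.5, -1.0], dtype=tf.float64)
jac32 = tf.constant([[0.1, 0.2], [0.3, 0.4]])  # tf defaults to float32

# TensorFlow refuses mixed-precision contractions rather than promoting.
try:
    tf.einsum("i,i...", dy64, jac32)
except Exception as exc:
    print(type(exc).__name__)

# Normalizing both operands to one dtype first, analogous to
# array(op, like=op[0], dtype=op[0].dtype) above, keeps the math in float64.
jac64 = tf.cast(jac32, tf.float64)
print(tf.einsum("i,i...", dy64, jac64).dtype)  # tf.float64
```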


2 changes: 1 addition & 1 deletion tests/gradients/core/test_vjp.py
@@ -117,7 +117,7 @@ def test_zero_dy_single_measurement_single_params(self):

def test_zero_dy_single_measurement_multi_params(self):
"""A zero dy vector will return a zero matrix"""
- dy = np.zeros([2])
+ dy = np.zeros(1)
jac = tuple([np.array(0.1), np.array(0.2)])

vjp = qml.gradients.compute_vjp_single(dy, jac)
9 changes: 9 additions & 0 deletions tests/gradients/finite_diff/test_finite_difference.py
@@ -477,6 +477,15 @@ def __init__(self, val):
def __mul__(self, other):
return SpecialObject(self.val * other)

+ def __rmul__(self, other):
+     return SpecialObject(other * self.val)
+
+ def __matmul__(self, other):
+     return SpecialObject(self.val @ other)
+
+ def __rmatmul__(self, other):
+     return SpecialObject(other @ self.val)
+

def __add__(self, other):
new = self.val + (other.val if isinstance(other, self.__class__) else other)
return SpecialObject(new)
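
These hooks are added because compute_vjp_single now attempts dy_row @ jac, and the finite-difference tests drive that path with object-dtype arrays of SpecialObject: NumPy's object-dtype matmul loop multiplies and sums elements through their Python operator hooks, and __rmatmul__ additionally covers the case where a SpecialObject itself sits on the right of @. A minimal sketch of that dispatch (assuming NumPy 1.17+, which provides the object matmul loop):

```python
import numpy as np

class SpecialObject:
    """Wraps a value and routes arithmetic through explicit hooks."""

    def __init__(self, val):
        self.val = val

    def __rmul__(self, other):
        # Hit when a plain number multiplies a SpecialObject entry.
        return SpecialObject(other * self.val)

    def __add__(self, other):
        other_val = other.val if isinstance(other, SpecialObject) else other
        return SpecialObject(self.val + other_val)

    __radd__ = __add__

dy_row = np.array([0.5, 2.0])
jac = np.array([SpecialObject(1.0), SpecialObject(3.0)], dtype=object)

# matmul on an object-dtype operand multiplies and sums via the hooks above.
res = dy_row @ jac
print(res.val)  # 0.5 * 1.0 + 2.0 * 3.0 = 6.5
```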