Skip to content

Commit

Permalink
Create VJP perf fix (#4806)
Browse files Browse the repository at this point in the history
### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      test directory!

- [x] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [ ] Add a new entry to the `doc/releases/changelog-dev.md` file,
summarizing the
      change, and including a link back to the PR.

- [x] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:** v0.33.1 backport of PR
#4792

**Description of the Change:** Updates VJP pipeline to favour direct
matrix-vector products where possible.

**Benefits:** Improves performance for many parameter/many observable
workloads

**Possible Drawbacks:**

**Related GitHub Issues:**
#4792

---------

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
  • Loading branch information
mlxd and timmysilv committed Nov 8, 2023
1 parent 98ea7eb commit f3b41b6
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/interface-unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ jobs:
"qcut-tests": ["3.9"],
"qchem-tests": ["3.9"],
"gradients-tests": ["3.9"],
"data-tests": ["3.10"],
"data-tests": ["3.9", "3.10"],
"device-tests": ["3.9"]
}
EOF
Expand Down
4 changes: 4 additions & 0 deletions doc/releases/changelog-0.33.1.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

<h3>Bug fixes 🐛</h3>

* Fix gradient performance regression due to expansion of VJP products.
  [(#4806)](https://github.com/PennyLaneAI/pennylane/pull/4806)

* `qml.defer_measurements` now correctly transforms circuits when terminal measurements include wires
used in mid-circuit measurements.
  [(#4787)](https://github.com/PennyLaneAI/pennylane/pull/4787)
Expand All @@ -31,4 +34,5 @@
This release contains contributions from (in alphabetical order):

Christina Lee,
Lee James O'Riordan,
Mudit Pandey
70 changes: 47 additions & 23 deletions pennylane/gradients/vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _all_close_to_zero(dy):
return qml.math.allclose(dy, 0)

# call this method recursively
return qml.math.all(qml.math.stack([_all_close_to_zero(dy_) for dy_ in dy]))
return all(_all_close_to_zero(dy_) for dy_ in dy)


def compute_vjp_single(dy, jac, num=None):
Expand Down Expand Up @@ -75,7 +75,7 @@ def compute_vjp_single(dy, jac, num=None):
>>> compute_vjp_single(dy, jac)
np.array([0.2])
2. For a single parameter and a single measurment with shape (e.g. probs):
2. For a single parameter and a single measurement with shape (e.g. probs):
.. code-block:: pycon
Expand Down Expand Up @@ -115,16 +115,8 @@ def compute_vjp_single(dy, jac, num=None):
if not isinstance(dy_row, np.ndarray):
jac = _convert(jac, dy_row)

try:
if _all_close_to_zero(dy):
# If the dy vector is zero, then the
# corresponding element of the VJP will be zero.
num_params = len(jac) if isinstance(jac, tuple) else 1

res = qml.math.convert_like(np.zeros(num_params), dy)
return qml.math.cast_like(res, dy)
except (AttributeError, TypeError):
pass
# Note: For generality, all exception type warnings are disabled.
# TODO: Explicitly catalogue and update raises for known types.

# Single measurement with a single param
if not isinstance(jac, (tuple, autograd.builtins.SequenceBox)):
Expand All @@ -136,7 +128,12 @@ def compute_vjp_single(dy, jac, num=None):
if num == 1:
jac = qml.math.squeeze(jac)
jac = qml.math.reshape(jac, (-1, 1))
res = qml.math.tensordot(jac, dy_row, [[0], [0]])
try:
res = dy_row @ jac

except Exception: # pylint: disable=broad-except
res = qml.math.tensordot(jac, dy_row, [[0], [0]])

# Single measurement with multiple params
else:
# No trainable parameters (adjoint)
Expand All @@ -146,12 +143,19 @@ def compute_vjp_single(dy, jac, num=None):
# Single measurement with no dimension e.g. expval
if num == 1:
jac = qml.math.reshape(qml.math.stack(jac), (1, -1))
res = qml.math.tensordot(jac, dy_row, [[0], [0]])
try:
res = dy_row @ jac
except Exception: # pylint: disable=broad-except
res = qml.math.tensordot(jac, dy_row, [[0], [0]])

# Single measurement with dimension e.g. probs
else:
jac = qml.math.stack(jac)
res = qml.math.tensordot(jac, dy_row, [[1], [0]])
try:
res = jac @ dy_row
except Exception: # pylint: disable=broad-except
res = qml.math.tensordot(jac, dy_row, [[1], [0]])

return res


Expand Down Expand Up @@ -202,14 +206,34 @@ def compute_vjp_multi(dy, jac, num=None):
res = qml.math.sum(qml.math.stack(res), axis=0)
# Multiple parameters
else:
res = []
for d, j_ in zip(dy, jac):
sub_res = []
for j in j_:
sub_res.append(qml.math.squeeze(compute_vjp_single(d, j, num=num)))
res.append(sub_res)
res = qml.math.stack([qml.math.stack(r) for r in res])
res = qml.math.sum(res, axis=0)
try:
dy_interface = qml.math.get_interface(dy[0])
# dy -> (i,j) observables, entries per observable
# jac -> (i,k,j) observables, parameters, entries per observable
# Contractions over observables and entries per observable
dy_shape = qml.math.shape(dy)
if len(dy_shape) > 1: # multiple values exist per observable output
return qml.math.array(qml.math.einsum("ij,i...j", dy, jac), like=dy[0])

if dy_interface == "tensorflow":
# TF needs a different path for Hessian support
return qml.math.array(
qml.math.einsum("i,i...", dy, jac, like=dy[0]), like=dy[0]
) # Scalar value per observable output
return qml.math.array(
qml.math.einsum("i,i...", dy, jac), like=dy[0]
) # Scalar value per observable output
# NOTE: We want any potential failure to fall back here, so catch every exception type
# TODO: Catalogue and update for expected exception types
except Exception: # pylint: disable=broad-except
res = []
for d, j_ in zip(dy, jac):
sub_res = []
for j in j_:
sub_res.append(qml.math.squeeze(compute_vjp_single(d, j, num=num)))
res.append(sub_res)
res = qml.math.stack([qml.math.stack(r) for r in res])
res = qml.math.sum(res, axis=0)
return res


Expand Down
7 changes: 7 additions & 0 deletions pennylane/math/multi_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,13 @@ def einsum(indices, *operands, like=None, optimize=None):
if optimize is None or like == "torch":
# torch einsum doesn't support the optimize keyword argument
return np.einsum(indices, *operands, like=like)
if like == "tensorflow":
# Unpacking and casting necessary for higher order derivatives,
# and avoiding implicit fp32 down-conversions.
op1, op2 = operands
op1 = array(op1, like=op1[0], dtype=op1[0].dtype)
op2 = array(op2, like=op2[0], dtype=op2[0].dtype)
return np.einsum(indices, op1, op2, like=like)
return np.einsum(indices, *operands, like=like, optimize=optimize)


Expand Down
2 changes: 1 addition & 1 deletion tests/gradients/core/test_vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_zero_dy_single_measurement_single_params(self):

def test_zero_dy_single_measurement_multi_params(self):
"""A zero dy vector will return a zero matrix"""
dy = np.zeros([2])
dy = np.zeros(1)
jac = tuple([np.array(0.1), np.array(0.2)])

vjp = qml.gradients.compute_vjp_single(dy, jac)
Expand Down
9 changes: 9 additions & 0 deletions tests/gradients/finite_diff/test_finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,15 @@ def __init__(self, val):
def __mul__(self, other):
return SpecialObject(self.val * other)

def __rmul__(self, other):
return SpecialObject(other * self.val)

def __matmul__(self, other):
return SpecialObject(self.val @ other)

def __rmatmul__(self, other):
return SpecialObject(other @ self.val)

def __add__(self, other):
new = self.val + (other.val if isinstance(other, self.__class__) else other)
return SpecialObject(new)
Expand Down

0 comments on commit f3b41b6

Please sign in to comment.