From f3b41b69cefbbc5cc5246c2b2492df6915f8be34 Mon Sep 17 00:00:00 2001
From: Lee James O'Riordan
Date: Wed, 8 Nov 2023 15:10:04 -0500
Subject: [PATCH] Create VJP perf fix (#4806)

### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test. If you've fixed a bug or added
  code that should be tested, add a test to the test directory!

- [x] All new functions and code must be clearly commented and documented. If you
  do make documentation changes, make sure that the docs build and render
  correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [ ] Add a new entry to the `doc/releases/changelog-dev.md` file, summarizing the
  change, and including a link back to the PR.

- [x] The PennyLane source code conforms to
  [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
  We check all of our code against [Pylint](https://www.pylint.org/).
  To lint modified files, simply `pip install pylint`, and then run
  `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.

------------------------------------------------------------------------------------------------------------

**Context:**

v0.33.1 backport of PR https://github.com/PennyLaneAI/pennylane/pull/4792

**Description of the Change:**

Updates the VJP pipeline to favour direct matrix-vector products where possible (illustrative sketches of the new contractions follow the diff below).

**Benefits:**

Improves performance for many-parameter / many-observable workloads.

**Possible Drawbacks:**

**Related GitHub Issues:**

https://github.com/PennyLaneAI/pennylane/pull/4792

---------

Co-authored-by: Matthew Silverman
---
 .github/workflows/interface-unit-tests.yml |  2 +-
 doc/releases/changelog-0.33.1.md           |  4 ++
 pennylane/gradients/vjp.py                 | 70 +++++++++++++------
 pennylane/math/multi_dispatch.py           |  7 ++
 tests/gradients/core/test_vjp.py           |  2 +-
 .../finite_diff/test_finite_difference.py  |  9 +++
 6 files changed, 69 insertions(+), 25 deletions(-)

diff --git a/.github/workflows/interface-unit-tests.yml b/.github/workflows/interface-unit-tests.yml
index e78e2e94dcd..0dc3dd25168 100644
--- a/.github/workflows/interface-unit-tests.yml
+++ b/.github/workflows/interface-unit-tests.yml
@@ -74,7 +74,7 @@ jobs:
             "qcut-tests": ["3.9"],
             "qchem-tests": ["3.9"],
             "gradients-tests": ["3.9"],
-            "data-tests": ["3.10"],
+            "data-tests": ["3.9", "3.10"],
             "device-tests": ["3.9"]
           }
           EOF
diff --git a/doc/releases/changelog-0.33.1.md b/doc/releases/changelog-0.33.1.md
index a31b30d8240..aa7bb956f87 100644
--- a/doc/releases/changelog-0.33.1.md
+++ b/doc/releases/changelog-0.33.1.md
@@ -14,6 +14,9 @@

 Bug fixes 🐛

+* Fix gradient performance regression due to expansion of VJP products.
+  [(#4806)](https://github.com/PennyLaneAI/pennylane/pull/4806)
+
 * `qml.defer_measurements` now correctly transforms circuits when terminal measurements
   include wires used in mid-circuit measurements.
   [(#4787)](https://github.com/PennyLaneAI/pennylane/pull/4787)

@@ -31,4 +34,5 @@ This release contains contributions from (in alphabetical order):

 Christina Lee,
+Lee James O'Riordan,
 Mudit Pandey
diff --git a/pennylane/gradients/vjp.py b/pennylane/gradients/vjp.py
index ae1ff46501a..50bc96853c3 100644
--- a/pennylane/gradients/vjp.py
+++ b/pennylane/gradients/vjp.py
@@ -47,7 +47,7 @@ def _all_close_to_zero(dy):
         return qml.math.allclose(dy, 0)

     # call this method recursively
-    return qml.math.all(qml.math.stack([_all_close_to_zero(dy_) for dy_ in dy]))
+    return all(_all_close_to_zero(dy_) for dy_ in dy)


 def compute_vjp_single(dy, jac, num=None):
@@ -75,7 +75,7 @@ def compute_vjp_single(dy, jac, num=None):
     >>> compute_vjp_single(dy, jac)
     np.array([0.2])

-    2. For a single parameter and a single measurment with shape (e.g. probs):
+    2. For a single parameter and a single measurement with shape (e.g. probs):

     .. code-block:: pycon

@@ -115,16 +115,8 @@ def compute_vjp_single(dy, jac, num=None):
     if not isinstance(dy_row, np.ndarray):
         jac = _convert(jac, dy_row)

-    try:
-        if _all_close_to_zero(dy):
-            # If the dy vector is zero, then the
-            # corresponding element of the VJP will be zero.
-            num_params = len(jac) if isinstance(jac, tuple) else 1
-
-            res = qml.math.convert_like(np.zeros(num_params), dy)
-            return qml.math.cast_like(res, dy)
-    except (AttributeError, TypeError):
-        pass
+    # Note: For generality, all exception types are caught below (broad-except warnings are disabled).
+    # TODO: Explicitly catalogue and update raises for known types.

     # Single measurement with a single param
     if not isinstance(jac, (tuple, autograd.builtins.SequenceBox)):
@@ -136,7 +128,12 @@ def compute_vjp_single(dy, jac, num=None):
         if num == 1:
             jac = qml.math.squeeze(jac)
         jac = qml.math.reshape(jac, (-1, 1))
-        res = qml.math.tensordot(jac, dy_row, [[0], [0]])
+        try:
+            res = dy_row @ jac
+
+        except Exception:  # pylint: disable=broad-except
+            res = qml.math.tensordot(jac, dy_row, [[0], [0]])
+
     # Single measurement with multiple params
     else:
         # No trainable parameters (adjoint)
@@ -146,12 +143,19 @@ def compute_vjp_single(dy, jac, num=None):
         # Single measurement with no dimension e.g. expval
         if num == 1:
             jac = qml.math.reshape(qml.math.stack(jac), (1, -1))
-            res = qml.math.tensordot(jac, dy_row, [[0], [0]])
+            try:
+                res = dy_row @ jac
+            except Exception:  # pylint: disable=broad-except
+                res = qml.math.tensordot(jac, dy_row, [[0], [0]])
         # Single measurement with dimension e.g. probs
         else:
             jac = qml.math.stack(jac)
-            res = qml.math.tensordot(jac, dy_row, [[1], [0]])
+            try:
+                res = jac @ dy_row
+            except Exception:  # pylint: disable=broad-except
+                res = qml.math.tensordot(jac, dy_row, [[1], [0]])
+
     return res


@@ -202,14 +206,34 @@ def compute_vjp_multi(dy, jac, num=None):
         res = qml.math.sum(qml.math.stack(res), axis=0)
     # Multiple parameters
     else:
-        res = []
-        for d, j_ in zip(dy, jac):
-            sub_res = []
-            for j in j_:
-                sub_res.append(qml.math.squeeze(compute_vjp_single(d, j, num=num)))
-            res.append(sub_res)
-        res = qml.math.stack([qml.math.stack(r) for r in res])
-        res = qml.math.sum(res, axis=0)
+        try:
+            dy_interface = qml.math.get_interface(dy[0])
+            # dy  -> (i, j)    observables, entries per observable
+            # jac -> (i, k, j) observables, parameters, entries per observable
+            # Contractions over observables and entries per observable
+            dy_shape = qml.math.shape(dy)
+            if len(dy_shape) > 1:  # multiple values exist per observable output
+                return qml.math.array(qml.math.einsum("ij,i...j", dy, jac), like=dy[0])
+
+            if dy_interface == "tensorflow":
+                # TF needs a different path for Hessian support
+                return qml.math.array(
+                    qml.math.einsum("i,i...", dy, jac, like=dy[0]), like=dy[0]
+                )  # Scalar value per observable output
+            return qml.math.array(
+                qml.math.einsum("i,i...", dy, jac), like=dy[0]
+            )  # Scalar value per observable output
+        # NOTE: We want any potential failure to fall back here, so catch every exception type
+        # TODO: Catalogue and update for expected exception types
+        except Exception:  # pylint: disable=broad-except
+            res = []
+            for d, j_ in zip(dy, jac):
+                sub_res = []
+                for j in j_:
+                    sub_res.append(qml.math.squeeze(compute_vjp_single(d, j, num=num)))
+                res.append(sub_res)
+            res = qml.math.stack([qml.math.stack(r) for r in res])
+            res = qml.math.sum(res, axis=0)

     return res

diff --git a/pennylane/math/multi_dispatch.py b/pennylane/math/multi_dispatch.py
index b8db858f527..d8739c7db7b 100644
--- a/pennylane/math/multi_dispatch.py
+++ b/pennylane/math/multi_dispatch.py
@@ -540,6 +540,13 @@ def einsum(indices, *operands, like=None, optimize=None):
     if optimize is None or like == "torch":
         # torch einsum doesn't support the optimize keyword argument
         return np.einsum(indices, *operands, like=like)
+    if like == "tensorflow":
+        # Unpacking and casting necessary for higher-order derivatives,
+        # and avoiding implicit fp32 down-conversions.
+        op1, op2 = operands
+        op1 = array(op1, like=op1[0], dtype=op1[0].dtype)
+        op2 = array(op2, like=op2[0], dtype=op2[0].dtype)
+        return np.einsum(indices, op1, op2, like=like)
     return np.einsum(indices, *operands, like=like, optimize=optimize)

diff --git a/tests/gradients/core/test_vjp.py b/tests/gradients/core/test_vjp.py
index 0904d1a19c2..b533389cc13 100644
--- a/tests/gradients/core/test_vjp.py
+++ b/tests/gradients/core/test_vjp.py
@@ -117,7 +117,7 @@ def test_zero_dy_single_measurement_single_params(self):

     def test_zero_dy_single_measurement_multi_params(self):
         """A zero dy vector will return a zero matrix"""
-        dy = np.zeros([2])
+        dy = np.zeros(1)
         jac = tuple([np.array(0.1), np.array(0.2)])

         vjp = qml.gradients.compute_vjp_single(dy, jac)
diff --git a/tests/gradients/finite_diff/test_finite_difference.py b/tests/gradients/finite_diff/test_finite_difference.py
index 0179a3036f6..9f1752dba82 100644
--- a/tests/gradients/finite_diff/test_finite_difference.py
+++ b/tests/gradients/finite_diff/test_finite_difference.py
@@ -477,6 +477,15 @@ def __init__(self, val):
         def __mul__(self, other):
             return SpecialObject(self.val * other)

+        def __rmul__(self, other):
+            return SpecialObject(other * self.val)
+
+        def __matmul__(self, other):
+            return SpecialObject(self.val @ other)
+
+        def __rmatmul__(self, other):
+            return SpecialObject(other @ self.val)
+
         def __add__(self, other):
             new = self.val + (other.val if isinstance(other, self.__class__) else other)
             return SpecialObject(new)
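---

For readers skimming the patch: the core of the change in `compute_vjp_single` is contracting the cotangent vector with the Jacobian via a plain matrix-vector product (`@`), keeping `qml.math.tensordot` only as a fallback for interfaces where `@` fails. A minimal NumPy-only sketch of the idea, with hypothetical shapes chosen purely for illustration:

```python
import numpy as np

# Hypothetical shapes: one measurement with m output entries (e.g. probs),
# differentiated with respect to p trainable parameters.
m, p = 4, 3
rng = np.random.default_rng(0)
dy_row = rng.random(m)        # cotangent vector ("dy") for the measurement
jac = rng.random((m, p))      # stacked Jacobian, output axis first

# Fast path introduced by the patch: a direct matrix-vector product.
try:
    vjp = dy_row @ jac
except Exception:  # pylint: disable=broad-except
    # Fallback mirrors the pre-existing tensordot contraction.
    vjp = np.tensordot(jac, dy_row, axes=[[0], [0]])

# Both contractions sum over the output axis and agree numerically.
assert np.allclose(vjp, np.tensordot(jac, dy_row, axes=[[0], [0]]))
```

The win comes from skipping the generic `tensordot` dispatch on the hot path; the broad `except` preserves the old behaviour for interface types that do not support `@`.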
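Similarly, the multi-measurement branch of `compute_vjp_multi` collapses the Python double loop into a single `einsum` over the observable axis. A sketch of the two contractions the patch uses, `"ij,i...j"` and `"i,i..."`, again with illustrative shapes:

```python
import numpy as np

n_obs, n_params, n_entries = 2, 3, 4
rng = np.random.default_rng(1)

# Scalar-valued observables: dy holds one entry per observable.
dy = rng.random(n_obs)                          # (i,)
jac = rng.random((n_obs, n_params))             # (i, k)
vjp = np.einsum("i,i...", dy, jac)              # contract over observables -> (k,)
assert vjp.shape == (n_params,)

# Vector-valued observables (e.g. probs): j entries per observable.
dy = rng.random((n_obs, n_entries))             # (i, j)
jac = rng.random((n_obs, n_params, n_entries))  # (i, k, j)
vjp = np.einsum("ij,i...j", dy, jac)            # contract over i and j -> (k,)
assert vjp.shape == (n_params,)

# Equivalent to summing per-observable VJPs computed one at a time,
# which is what the loop-based fallback does.
ref = sum(np.einsum("j,kj->k", d, j) for d, j in zip(dy, jac))
assert np.allclose(vjp, ref)
```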
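Finally, the `multi_dispatch.einsum` tweak: on the TensorFlow path the operands can arrive as Python sequences, and converting a sequence without an explicit `dtype` lets TensorFlow default to `float32`, silently degrading `float64` gradients. A rough illustration of that failure mode, based on my reading of the in-diff comment (assumes TensorFlow is installed; `tf.constant` stands in for the `array(..., dtype=...)` call the patch actually uses):

```python
import tensorflow as tf

# Plain Python floats, as can appear in a dy/jac sequence handed to einsum.
vals = [0.1, 0.2]

# Without an explicit dtype, TensorFlow converts Python floats to float32:
print(tf.constant(vals).dtype)                    # <dtype: 'float32'>

# The patch pins the dtype to the first element's dtype, keeping float64 intact:
print(tf.constant(vals, dtype=tf.float64).dtype)  # <dtype: 'float64'>
```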