From 36dce77cafc0af70acbd0073c2431de1728cd436 Mon Sep 17 00:00:00 2001 From: "Lee J. O'Riordan" Date: Wed, 8 Nov 2023 10:56:01 -0500 Subject: [PATCH 1/5] Create VJP perf fix --- pennylane/gradients/vjp.py | 70 +++++++++++++------ pennylane/math/multi_dispatch.py | 7 ++ tests/gradients/core/test_vjp.py | 2 +- .../finite_diff/test_finite_difference.py | 9 +++ 4 files changed, 64 insertions(+), 24 deletions(-) diff --git a/pennylane/gradients/vjp.py b/pennylane/gradients/vjp.py index ae1ff46501a..bc9052d8436 100644 --- a/pennylane/gradients/vjp.py +++ b/pennylane/gradients/vjp.py @@ -47,7 +47,7 @@ def _all_close_to_zero(dy): return qml.math.allclose(dy, 0) # call this method recursively - return qml.math.all(qml.math.stack([_all_close_to_zero(dy_) for dy_ in dy])) + return all(_all_close_to_zero(dy_) for dy_ in dy) def compute_vjp_single(dy, jac, num=None): @@ -75,7 +75,7 @@ def compute_vjp_single(dy, jac, num=None): >>> compute_vjp_single(dy, jac) np.array([0.2]) - 2. For a single parameter and a single measurment with shape (e.g. probs): + 2. For a single parameter and a single measurement with shape (e.g. probs): .. code-block:: pycon @@ -115,16 +115,8 @@ def compute_vjp_single(dy, jac, num=None): if not isinstance(dy_row, np.ndarray): jac = _convert(jac, dy_row) - try: - if _all_close_to_zero(dy): - # If the dy vector is zero, then the - # corresponding element of the VJP will be zero. - num_params = len(jac) if isinstance(jac, tuple) else 1 - - res = qml.math.convert_like(np.zeros(num_params), dy) - return qml.math.cast_like(res, dy) - except (AttributeError, TypeError): - pass + # Note: For generality, all exception type warnings are disabled. + # TODO: Excplictly catalogue and update raises for known types. # Single measurement with a single param if not isinstance(jac, (tuple, autograd.builtins.SequenceBox)): @@ -136,7 +128,12 @@ def compute_vjp_single(dy, jac, num=None): if num == 1: jac = qml.math.squeeze(jac) jac = qml.math.reshape(jac, (-1, 1)) - res = qml.math.tensordot(jac, dy_row, [[0], [0]]) + try: + res = dy_row @ jac + + except Exception: # pylint: disable=broad-exception-caught + res = qml.math.tensordot(jac, dy_row, [[0], [0]]) + # Single measurement with multiple params else: # No trainable parameters (adjoint) @@ -146,12 +143,19 @@ def compute_vjp_single(dy, jac, num=None): # Single measurement with no dimension e.g. expval if num == 1: jac = qml.math.reshape(qml.math.stack(jac), (1, -1)) - res = qml.math.tensordot(jac, dy_row, [[0], [0]]) + try: + res = dy_row @ jac + except Exception: # pylint: disable=broad-exception-caught + res = qml.math.tensordot(jac, dy_row, [[0], [0]]) # Single measurement with dimension e.g. probs else: jac = qml.math.stack(jac) - res = qml.math.tensordot(jac, dy_row, [[1], [0]]) + try: + res = jac @ dy_row + except Exception: # pylint: disable=broad-exception-caught + res = qml.math.tensordot(jac, dy_row, [[1], [0]]) + return res @@ -202,14 +206,34 @@ def compute_vjp_multi(dy, jac, num=None): res = qml.math.sum(qml.math.stack(res), axis=0) # Multiple parameters else: - res = [] - for d, j_ in zip(dy, jac): - sub_res = [] - for j in j_: - sub_res.append(qml.math.squeeze(compute_vjp_single(d, j, num=num))) - res.append(sub_res) - res = qml.math.stack([qml.math.stack(r) for r in res]) - res = qml.math.sum(res, axis=0) + try: + dy_interface = qml.math.get_interface(dy[0]) + # dy -> (i,j) observables, entries per observable + # jac -> (i,k,j) observables, parameters, entries per observable + # Contractions over observables and entries per observable + dy_shape = qml.math.shape(dy) + if len(dy_shape) > 1: # multiple values exist per observable output + return qml.math.array(qml.math.einsum("ij,i...j", dy, jac), like=dy[0]) + + if dy_interface == "tensorflow": + # TF needs a different path for Hessian support + return qml.math.array( + qml.math.einsum("i,i...", dy, jac, like=dy[0]), like=dy[0] + ) # Scalar value per observable output + return qml.math.array( + qml.math.einsum("i,i...", dy, jac), like=dy[0] + ) # Scalar value per observable output + # NOTE: We want any potential failure to fall back here, so catch every exception type + # TODO: Catalogue and update for expected exception types + except Exception: # pylint: disable=broad-exception-caught + res = [] + for d, j_ in zip(dy, jac): + sub_res = [] + for j in j_: + sub_res.append(qml.math.squeeze(compute_vjp_single(d, j, num=num))) + res.append(sub_res) + res = qml.math.stack([qml.math.stack(r) for r in res]) + res = qml.math.sum(res, axis=0) return res diff --git a/pennylane/math/multi_dispatch.py b/pennylane/math/multi_dispatch.py index b8db858f527..d8739c7db7b 100644 --- a/pennylane/math/multi_dispatch.py +++ b/pennylane/math/multi_dispatch.py @@ -540,6 +540,13 @@ def einsum(indices, *operands, like=None, optimize=None): if optimize is None or like == "torch": # torch einsum doesn't support the optimize keyword argument return np.einsum(indices, *operands, like=like) + if like == "tensorflow": + # Unpacking and casting necessary for higher order derivatives, + # and avoiding implicit fp32 down-conversions. + op1, op2 = operands + op1 = array(op1, like=op1[0], dtype=op1[0].dtype) + op2 = array(op2, like=op2[0], dtype=op2[0].dtype) + return np.einsum(indices, op1, op2, like=like) return np.einsum(indices, *operands, like=like, optimize=optimize) diff --git a/tests/gradients/core/test_vjp.py b/tests/gradients/core/test_vjp.py index 0904d1a19c2..b533389cc13 100644 --- a/tests/gradients/core/test_vjp.py +++ b/tests/gradients/core/test_vjp.py @@ -117,7 +117,7 @@ def test_zero_dy_single_measurement_single_params(self): def test_zero_dy_single_measurement_multi_params(self): """A zero dy vector will return a zero matrix""" - dy = np.zeros([2]) + dy = np.zeros(1) jac = tuple([np.array(0.1), np.array(0.2)]) vjp = qml.gradients.compute_vjp_single(dy, jac) diff --git a/tests/gradients/finite_diff/test_finite_difference.py b/tests/gradients/finite_diff/test_finite_difference.py index 0179a3036f6..9f1752dba82 100644 --- a/tests/gradients/finite_diff/test_finite_difference.py +++ b/tests/gradients/finite_diff/test_finite_difference.py @@ -477,6 +477,15 @@ def __init__(self, val): def __mul__(self, other): return SpecialObject(self.val * other) + def __rmul__(self, other): + return SpecialObject(other * self.val) + + def __matmul__(self, other): + return SpecialObject(self.val @ other) + + def __rmatmul__(self, other): + return SpecialObject(other @ self.val) + def __add__(self, other): new = self.val + (other.val if isinstance(other, self.__class__) else other) return SpecialObject(new) From dfabb061918deda7683f298d780e6a36a2e6ff1d Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Wed, 8 Nov 2023 12:01:12 -0500 Subject: [PATCH 2/5] fix pylint disable with python 3.9 --- pennylane/gradients/vjp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pennylane/gradients/vjp.py b/pennylane/gradients/vjp.py index bc9052d8436..50bc96853c3 100644 --- a/pennylane/gradients/vjp.py +++ b/pennylane/gradients/vjp.py @@ -131,7 +131,7 @@ def compute_vjp_single(dy, jac, num=None): try: res = dy_row @ jac - except Exception: # pylint: disable=broad-exception-caught + except Exception: # pylint: disable=broad-except res = qml.math.tensordot(jac, dy_row, [[0], [0]]) # Single measurement with multiple params @@ -145,7 +145,7 @@ def compute_vjp_single(dy, jac, num=None): jac = qml.math.reshape(qml.math.stack(jac), (1, -1)) try: res = dy_row @ jac - except Exception: # pylint: disable=broad-exception-caught + except Exception: # pylint: disable=broad-except res = qml.math.tensordot(jac, dy_row, [[0], [0]]) # Single measurement with dimension e.g. probs @@ -153,7 +153,7 @@ def compute_vjp_single(dy, jac, num=None): jac = qml.math.stack(jac) try: res = jac @ dy_row - except Exception: # pylint: disable=broad-exception-caught + except Exception: # pylint: disable=broad-except res = qml.math.tensordot(jac, dy_row, [[1], [0]]) return res @@ -225,7 +225,7 @@ def compute_vjp_multi(dy, jac, num=None): ) # Scalar value per observable output # NOTE: We want any potential failure to fall back here, so catch every exception type # TODO: Catalogue and update for expected exception types - except Exception: # pylint: disable=broad-exception-caught + except Exception: # pylint: disable=broad-except res = [] for d, j_ in zip(dy, jac): sub_res = [] From 413630022ae06c69ef1f953af92f9a03fca05b0d Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Wed, 8 Nov 2023 13:16:02 -0500 Subject: [PATCH 3/5] also run data tests for 3.9 in full test suite --- .github/workflows/interface-unit-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/interface-unit-tests.yml b/.github/workflows/interface-unit-tests.yml index e78e2e94dcd..0dc3dd25168 100644 --- a/.github/workflows/interface-unit-tests.yml +++ b/.github/workflows/interface-unit-tests.yml @@ -74,7 +74,7 @@ jobs: "qcut-tests": ["3.9"], "qchem-tests": ["3.9"], "gradients-tests": ["3.9"], - "data-tests": ["3.10"], + "data-tests": ["3.9", "3.10"], "device-tests": ["3.9"] } EOF From fa97d5326ee37e92280f19a3b898b8535eb03c5a Mon Sep 17 00:00:00 2001 From: "Lee J. O'Riordan" Date: Wed, 8 Nov 2023 15:08:49 -0500 Subject: [PATCH 4/5] Update changelog --- doc/releases/changelog-0.33.1.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/releases/changelog-0.33.1.md b/doc/releases/changelog-0.33.1.md index a31b30d8240..aa7bb956f87 100644 --- a/doc/releases/changelog-0.33.1.md +++ b/doc/releases/changelog-0.33.1.md @@ -14,6 +14,9 @@

Bug fixes 🐛

+* Fix gradient performance regression due to expansion of VJP products. + [(#4806)](https://github.com/PennyLaneAI/pennylane/pull/4806) + * `qml.defer_measurements` now correctly transforms circuits when terminal measurements include wires used in mid-circuit measurements. [(#4787)](https://github.com/PennyLaneAI/pennylane/pull/4787) @@ -31,4 +34,5 @@ This release contains contributions from (in alphabetical order): Christina Lee, +Lee James O'Riordan, Mudit Pandey From 96070f27757a5e4950d322217f036e2cf2c88f0b Mon Sep 17 00:00:00 2001 From: "Lee J. O'Riordan" Date: Wed, 8 Nov 2023 15:09:23 -0500 Subject: [PATCH 5/5] [skip-ci]