diff --git a/.github/workflows/interface-unit-tests.yml b/.github/workflows/interface-unit-tests.yml
index e78e2e94dcd..0dc3dd25168 100644
--- a/.github/workflows/interface-unit-tests.yml
+++ b/.github/workflows/interface-unit-tests.yml
@@ -74,7 +74,7 @@ jobs:
"qcut-tests": ["3.9"],
"qchem-tests": ["3.9"],
"gradients-tests": ["3.9"],
- "data-tests": ["3.10"],
+ "data-tests": ["3.9", "3.10"],
"device-tests": ["3.9"]
}
EOF
diff --git a/doc/releases/changelog-0.33.1.md b/doc/releases/changelog-0.33.1.md
index a31b30d8240..aa7bb956f87 100644
--- a/doc/releases/changelog-0.33.1.md
+++ b/doc/releases/changelog-0.33.1.md
@@ -14,6 +14,9 @@
Bug fixes 🐛
+* Fix gradient performance regression due to expansion of VJP products.
+ [(#4806)](https://github.com/PennyLaneAI/pennylane/pull/4806)
+
* `qml.defer_measurements` now correctly transforms circuits when terminal measurements include wires
used in mid-circuit measurements.
[(#4787)](https://github.com/PennyLaneAI/pennylane/pull/4787)
@@ -31,4 +34,5 @@
This release contains contributions from (in alphabetical order):
Christina Lee,
+Lee James O'Riordan,
Mudit Pandey
diff --git a/pennylane/gradients/vjp.py b/pennylane/gradients/vjp.py
index ae1ff46501a..50bc96853c3 100644
--- a/pennylane/gradients/vjp.py
+++ b/pennylane/gradients/vjp.py
@@ -47,7 +47,7 @@ def _all_close_to_zero(dy):
return qml.math.allclose(dy, 0)
# call this method recursively
- return qml.math.all(qml.math.stack([_all_close_to_zero(dy_) for dy_ in dy]))
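+    # A plain ``all`` over a generator short-circuits on the first entry that
+    # is not close to zero and avoids stacking the results into an interface array.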
+ return all(_all_close_to_zero(dy_) for dy_ in dy)
def compute_vjp_single(dy, jac, num=None):
@@ -75,7 +75,7 @@ def compute_vjp_single(dy, jac, num=None):
>>> compute_vjp_single(dy, jac)
np.array([0.2])
- 2. For a single parameter and a single measurment with shape (e.g. probs):
+ 2. For a single parameter and a single measurement with shape (e.g. probs):
.. code-block:: pycon
@@ -115,16 +115,8 @@ def compute_vjp_single(dy, jac, num=None):
if not isinstance(dy_row, np.ndarray):
jac = _convert(jac, dy_row)
- try:
- if _all_close_to_zero(dy):
- # If the dy vector is zero, then the
- # corresponding element of the VJP will be zero.
- num_params = len(jac) if isinstance(jac, tuple) else 1
-
- res = qml.math.convert_like(np.zeros(num_params), dy)
- return qml.math.cast_like(res, dy)
- except (AttributeError, TypeError):
- pass
+    # Note: For generality, broad exception handling is used below, with the
+    # corresponding pylint warnings disabled.
+    # TODO: Explicitly catalogue the known exception types and narrow the handlers.
# Single measurement with a single param
if not isinstance(jac, (tuple, autograd.builtins.SequenceBox)):
@@ -136,7 +128,12 @@ def compute_vjp_single(dy, jac, num=None):
if num == 1:
jac = qml.math.squeeze(jac)
jac = qml.math.reshape(jac, (-1, 1))
- res = qml.math.tensordot(jac, dy_row, [[0], [0]])
+ try:
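+            # Prefer the interface's native matmul; it is considerably cheaper
+            # than dispatching through ``qml.math.tensordot``.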
+ res = dy_row @ jac
+
+ except Exception: # pylint: disable=broad-except
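+            # Fall back to tensordot for object types that do not support ``@``.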
+ res = qml.math.tensordot(jac, dy_row, [[0], [0]])
+
# Single measurement with multiple params
else:
# No trainable parameters (adjoint)
@@ -146,12 +143,19 @@ def compute_vjp_single(dy, jac, num=None):
# Single measurement with no dimension e.g. expval
if num == 1:
jac = qml.math.reshape(qml.math.stack(jac), (1, -1))
- res = qml.math.tensordot(jac, dy_row, [[0], [0]])
+ try:
+ res = dy_row @ jac
+ except Exception: # pylint: disable=broad-except
+ res = qml.math.tensordot(jac, dy_row, [[0], [0]])
# Single measurement with dimension e.g. probs
else:
jac = qml.math.stack(jac)
- res = qml.math.tensordot(jac, dy_row, [[1], [0]])
+ try:
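+                    # ``jac @ dy_row`` contracts the trailing axis of the stacked
+                    # Jacobian with ``dy_row``, matching the tensordot fallback.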
+ res = jac @ dy_row
+ except Exception: # pylint: disable=broad-except
+ res = qml.math.tensordot(jac, dy_row, [[1], [0]])
+
return res
@@ -202,14 +206,34 @@ def compute_vjp_multi(dy, jac, num=None):
res = qml.math.sum(qml.math.stack(res), axis=0)
# Multiple parameters
else:
- res = []
- for d, j_ in zip(dy, jac):
- sub_res = []
- for j in j_:
- sub_res.append(qml.math.squeeze(compute_vjp_single(d, j, num=num)))
- res.append(sub_res)
- res = qml.math.stack([qml.math.stack(r) for r in res])
- res = qml.math.sum(res, axis=0)
+ try:
+ dy_interface = qml.math.get_interface(dy[0])
+            # dy  -> (i, j):    observables, entries per observable
+            # jac -> (i, k, j): observables, parameters, entries per observable
+            # Contract over the observables (i) and the entries per observable (j)
+ dy_shape = qml.math.shape(dy)
+ if len(dy_shape) > 1: # multiple values exist per observable output
+ return qml.math.array(qml.math.einsum("ij,i...j", dy, jac), like=dy[0])
+
+ if dy_interface == "tensorflow":
+ # TF needs a different path for Hessian support
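+                # Passing ``like`` here routes the contraction through the
+                # TensorFlow branch of ``qml.math.einsum`` (see multi_dispatch.py).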
+ return qml.math.array(
+ qml.math.einsum("i,i...", dy, jac, like=dy[0]), like=dy[0]
+ ) # Scalar value per observable output
+ return qml.math.array(
+ qml.math.einsum("i,i...", dy, jac), like=dy[0]
+ ) # Scalar value per observable output
+        # NOTE: We want any potential failure to fall back to the loop below,
+        # so every exception type is caught here.
+        # TODO: Catalogue the expected exception types and narrow the handler.
+ except Exception: # pylint: disable=broad-except
+ res = []
+ for d, j_ in zip(dy, jac):
+ sub_res = []
+ for j in j_:
+ sub_res.append(qml.math.squeeze(compute_vjp_single(d, j, num=num)))
+ res.append(sub_res)
+ res = qml.math.stack([qml.math.stack(r) for r in res])
+ res = qml.math.sum(res, axis=0)
return res
diff --git a/pennylane/math/multi_dispatch.py b/pennylane/math/multi_dispatch.py
index b8db858f527..d8739c7db7b 100644
--- a/pennylane/math/multi_dispatch.py
+++ b/pennylane/math/multi_dispatch.py
@@ -540,6 +540,13 @@ def einsum(indices, *operands, like=None, optimize=None):
if optimize is None or like == "torch":
# torch einsum doesn't support the optimize keyword argument
return np.einsum(indices, *operands, like=like)
+ if like == "tensorflow":
+        # Unpacking and casting are necessary for higher-order derivatives
+        # and to avoid implicit fp32 down-conversions.
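+        # NOTE: this branch assumes exactly two operands, which holds for the
+        # VJP contractions that dispatch here.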
+ op1, op2 = operands
+ op1 = array(op1, like=op1[0], dtype=op1[0].dtype)
+ op2 = array(op2, like=op2[0], dtype=op2[0].dtype)
+ return np.einsum(indices, op1, op2, like=like)
return np.einsum(indices, *operands, like=like, optimize=optimize)
diff --git a/tests/gradients/core/test_vjp.py b/tests/gradients/core/test_vjp.py
index 0904d1a19c2..b533389cc13 100644
--- a/tests/gradients/core/test_vjp.py
+++ b/tests/gradients/core/test_vjp.py
@@ -117,7 +117,7 @@ def test_zero_dy_single_measurement_single_params(self):
def test_zero_dy_single_measurement_multi_params(self):
"""A zero dy vector will return a zero matrix"""
- dy = np.zeros([2])
+ dy = np.zeros(1)
jac = tuple([np.array(0.1), np.array(0.2)])
vjp = qml.gradients.compute_vjp_single(dy, jac)
diff --git a/tests/gradients/finite_diff/test_finite_difference.py b/tests/gradients/finite_diff/test_finite_difference.py
index 0179a3036f6..9f1752dba82 100644
--- a/tests/gradients/finite_diff/test_finite_difference.py
+++ b/tests/gradients/finite_diff/test_finite_difference.py
@@ -477,6 +477,15 @@ def __init__(self, val):
def __mul__(self, other):
return SpecialObject(self.val * other)
+ def __rmul__(self, other):
+ return SpecialObject(other * self.val)
+
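+    # ``@`` support mirrors the matmul fast path now used by ``compute_vjp_single``.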
+ def __matmul__(self, other):
+ return SpecialObject(self.val @ other)
+
+ def __rmatmul__(self, other):
+ return SpecialObject(other @ self.val)
+
def __add__(self, other):
new = self.val + (other.val if isinstance(other, self.__class__) else other)
return SpecialObject(new)