Make the contraction of quantum and classical Jacobians consistent in gradient_transform #4945

Merged · 21 commits · Apr 25, 2024
Changes from 12 commits
9 changes: 9 additions & 0 deletions doc/releases/changelog-dev.md
@@ -208,6 +208,10 @@

<h3>Breaking changes 💔</h3>

+* Applying a `gradient_transform` to a QNode directly now gives the same shape and type,
+  independent of whether the QNode contains classical processing.
+  [(#4945)](https://github.com/PennyLaneAI/pennylane/pull/4945)
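  As a minimal sketch of the new behavior (illustrative only; the gate choice and values are arbitrary), a QNode that rescales its argument classically now returns its gradient with the same structure as one that does not:

  ```python
  import pennylane as qml
  from pennylane import numpy as np

  dev = qml.device("default.qubit", wires=1)

  @qml.qnode(dev)
  def plain(w):
      qml.RX(w, wires=0)
      return qml.expval(qml.PauliZ(0))

  @qml.qnode(dev)
  def processed(w):
      qml.RX(2.0 * w, wires=0)  # classical processing inside the QNode
      return qml.expval(qml.PauliZ(0))

  w = np.array(0.543, requires_grad=True)
  # Both gradients now come back with matching shape and type; previously the
  # unprocessed QNode took an identity-Jacobian shortcut and could differ.
  g_plain = qml.gradients.param_shift(plain)(w)
  g_processed = qml.gradients.param_shift(processed)(w)
  ```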

* The private functions `_pauli_mult`, `_binary_matrix` and `_get_pauli_map` from the `pauli` module have been removed. The same functionality can be achieved using newer features in the `pauli` module.
[(#5323)](https://github.com/PennyLaneAI/pennylane/pull/5323)

@@ -291,6 +295,10 @@

<h3>Bug fixes 🐛</h3>

+* Fixed a bug where the shape and type of derivatives obtained by applying a gradient transform
+  to a QNode differed depending on whether the QNode uses classical coprocessing.
+  [(#4945)](https://github.com/PennyLaneAI/pennylane/pull/4945)
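  The mechanism behind the fix, sketched with plain NumPy (illustrative values; `qjac` and `cjac` here stand in for the quantum and classical Jacobians that the contraction receives): with classical processing such as `RX(2.0 * w)`, the classical Jacobian is `[2.0]`, and the contraction applies the chain rule instead of being skipped.

  ```python
  import numpy as np

  w = 0.543
  qjac = np.array([-np.sin(2.0 * w)])  # d<Z>/d(gate angle), from the parameter-shift rule
  cjac = np.array([2.0])               # d(gate angle)/dw, the classical Jacobian
  # Contracting over the tape-parameter axis implements the chain rule: d<Z>/dw = -2 sin(2w)
  grad = np.tensordot(qjac, cjac, axes=[[0], [0]])
  ```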

* Fixed Torch tensor locality with the autoray-registered `coerce` method.
[(#5438)](https://github.com/PennyLaneAI/pennylane/pull/5438)

@@ -336,4 +344,5 @@ Christina Lee,
Vincent Michaud-Rioux,
Mudit Pandey,
Jay Soni,
+David Wierichs,
Matthew Silverman.
52 changes: 40 additions & 12 deletions pennylane/gradients/gradient_transform.py
@@ -391,52 +391,80 @@ def _contract_qjac_with_cjac(qjac, cjac, tape):
        cjac = cjac[0]

    cjac_is_tuple = isinstance(cjac, tuple)
-   if not cjac_is_tuple:
-       is_square = cjac.ndim == 2 and cjac.shape[0] == cjac.shape[1]
+   # skip_cjac = False
+   # if not cjac_is_tuple:
+   #     is_square = cjac.ndim == 2 and cjac.shape[0] == cjac.shape[1]

-       if not qml.math.is_abstract(cjac) and (
-           is_square and qml.math.allclose(cjac, qml.numpy.eye(cjac.shape[0]))
-       ):
-           # Classical Jacobian is the identity. No classical processing is present in the QNode
-           return qjac
+   # if not qml.math.is_abstract(cjac) and (
+   #     is_square and qml.math.allclose(cjac, qml.numpy.eye(cjac.shape[0]))
+   # ):
+   #     skip_cjac = True

    multi_meas = num_measurements > 1

    if cjac_is_tuple:
        multi_params = True
+       single_tape_param = False
    else:
        _qjac = qjac
        if multi_meas:
            _qjac = _qjac[0]
        if has_partitioned_shots:
            _qjac = _qjac[0]
-       multi_params = isinstance(_qjac, tuple)
+       single_tape_param = not isinstance(_qjac, tuple)

    tdot = partial(qml.math.tensordot, axes=[[0], [0]])

-   if not multi_params:
+   if single_tape_param:
        # Without dimension (e.g. expval) or with dimension (e.g. probs)
        def _reshape(x):
            return qml.math.reshape(x, (1,) if x.shape == () else (1, -1))

        if not (multi_meas or has_partitioned_shots):
-           # Single parameter, single measurements
+           # Single parameter, single measurements, no shot vector
+           # if skip_cjac:
+           #     return qml.math.moveaxis(_reshape(qjac), 0, -1)
            return tdot(_reshape(qjac), cjac)

        if not (multi_meas and has_partitioned_shots):
            # Single parameter, multiple measurements or shot vector, but not both
+           # if skip_cjac:
+           #     return tuple(qml.math.moveaxis(_reshape(q), 0, -1) for q in qjac)
            return tuple(tdot(_reshape(q), cjac) for q in qjac)

-       # Single parameter, multiple measurements
+       # Single parameter, multiple measurements, and shot vector
+       # if skip_cjac:
+       #     return tuple(tuple(qml.math.moveaxis(_reshape(_q), 0, -1) for _q in q) for q in qjac)
        return tuple(tuple(tdot(_reshape(_q), cjac) for _q in q) for q in qjac)

    if not multi_meas:
        # Multiple parameters, single measurement
        qjac = qml.math.stack(qjac)
+       if not cjac_is_tuple:
+           if has_partitioned_shots:
+               # if skip_cjac:
+               #     return tuple(qml.math.moveaxis(qml.math.stack(q), 0, -1) for q in qjac)
+               return tuple(tdot(qml.math.stack(q), qml.math.stack(cjac)) for q in qjac)
+           # if skip_cjac:
+           #     return qml.math.moveaxis(qjac, 0, -1)
+           return tdot(qjac, qml.math.stack(cjac))
        if has_partitioned_shots:
            return tuple(tuple(tdot(q, c) for c in cjac if c is not None) for q in qjac)
        return tuple(tdot(qjac, c) for c in cjac if c is not None)

    # Multiple parameters, multiple measurements
+   if not cjac_is_tuple:
+       if has_partitioned_shots:
+           # if skip_cjac:
+           #     return tuple(tuple(qml.math.moveaxis(qml.math.stack(_q), 0, -1) for _q in q) for q in qjac)
+           return tuple(
+               tuple(tdot(qml.math.stack(_q), qml.math.stack(cjac)) for _q in q) for q in qjac
+           )
+       # if skip_cjac:
+       #     return tuple(qml.math.moveaxis(qml.math.stack(q), 0, -1) for q in qjac)
+       return tuple(tdot(qml.math.stack(q), qml.math.stack(cjac)) for q in qjac)
    if has_partitioned_shots:
        return tuple(
            tuple(tuple(tdot(qml.math.stack(_q), c) for c in cjac if c is not None) for _q in q)
            for q in qjac
        )
    return tuple(tuple(tdot(qml.math.stack(q), c) for c in cjac if c is not None) for q in qjac)
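The `tdot` helper above contracts the leading axis of both Jacobians, which by convention indexes the tape's trainable parameters. A standalone NumPy sketch of that convention (shapes are made up for illustration):

```python
from functools import partial

import numpy as np

tdot = partial(np.tensordot, axes=[[0], [0]])

n_tape_params, n_qnode_args, meas_dim = 3, 2, 4
qjac = np.ones((n_tape_params, meas_dim))      # quantum Jacobian: axis 0 = tape parameters
cjac = np.ones((n_tape_params, n_qnode_args))  # classical Jacobian: axis 0 = tape parameters

# Measurement axes come first, QNode-argument axes last: prints (4, 2)
print(tdot(qjac, cjac).shape)
```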
30 changes: 22 additions & 8 deletions tests/gradients/core/test_gradient_transform.py
@@ -186,50 +186,64 @@ class TestGradientTransformIntegration:

    @pytest.mark.parametrize("shots, atol", [(None, 1e-6), (1000, 1e-1), ([1000, 500], 3e-1)])
    @pytest.mark.parametrize("slicing", [False, True])
-   def test_acting_on_qnodes_single_param(self, shots, slicing, atol):
+   @pytest.mark.parametrize("prefactor", [1.0, 2.0])
+   def test_acting_on_qnodes_single_param(self, shots, slicing, prefactor, atol):
        """Test that a gradient transform acts on QNodes with a single parameter correctly"""
        np.random.seed(412)
        dev = qml.device("default.qubit", wires=2, shots=shots)

        @qml.qnode(dev)
        def circuit(weights):
            if slicing:
-               qml.RX(weights[0], wires=[0])
+               qml.RX(prefactor * weights[0], wires=[0])
            else:
-               qml.RX(weights, wires=[0])
+               qml.RX(prefactor * weights, wires=[0])
            return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliX(1))

        grad_fn = qml.gradients.param_shift(circuit)

        w = np.array([0.543] if slicing else 0.543, requires_grad=True)
        res = grad_fn(w)
        assert circuit.interface == "auto"
-       expected = np.array([-np.sin(w[0] if slicing else w), 0])
+
+       # Need to multiply 0 with w to get the right output shape for non-scalar w
+       expected = (-prefactor * np.sin(prefactor * w), w * 0)
+
        if isinstance(shots, list):
            assert all(np.allclose(r, expected, atol=atol, rtol=0) for r in res)
        else:
            assert np.allclose(res, expected, atol=atol, rtol=0)
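As a cross-check of the expected value in this test (an illustrative verification sketch, assuming the analytic model ⟨Z₀⟩ = cos(p·w) for `RX(p * w)` acting on |0⟩): the derivative is −p·sin(p·w), while `qml.var(qml.PauliX(1))` on the untouched wire is constant, so its derivative vanishes; `w * 0` merely matches the output shape for array-valued `w`.

```python
import numpy as np

p, w = 2.0, 0.543
eps = 1e-6
# Central finite difference of <Z0> = cos(p * w) against the analytic derivative
fd = (np.cos(p * (w + eps)) - np.cos(p * (w - eps))) / (2 * eps)
assert np.isclose(fd, -p * np.sin(p * w))
```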

    @pytest.mark.parametrize("shots, atol", [(None, 1e-6), (1000, 1e-1), ([1000, 100], 2e-1)])
-   def test_acting_on_qnodes_multi_param(self, shots, atol):
+   @pytest.mark.parametrize("prefactor", [1.0, 2.0])
+   def test_acting_on_qnodes_multi_param(self, shots, prefactor, atol):
        """Test that a gradient transform acts on QNodes with multiple parameters correctly"""
        np.random.seed(412)
        dev = qml.device("default.qubit", wires=2, shots=shots)

        @qml.qnode(dev)
        def circuit(weights):
            qml.RX(weights[0], wires=[0])
-           qml.RY(weights[1], wires=[1])
+           qml.RY(prefactor * weights[1], wires=[1])
            qml.CNOT(wires=[0, 1])
-           return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliX(1))
+           return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliZ(1))

        grad_fn = qml.gradients.param_shift(circuit)

        w = np.array([0.543, -0.654], requires_grad=True)
        res = grad_fn(w)
        assert circuit.interface == "auto"
        x, y = w
-       expected = np.array([[-np.sin(x), 0], [0, -2 * np.cos(y) * np.sin(y)]])
+       y *= prefactor
+       expected = np.array(
+           [
+               [-np.sin(x), 0],
+               [
+                   2 * np.cos(x) * np.sin(x) * np.cos(y) ** 2,
+                   2 * prefactor * np.cos(y) * np.sin(y) * np.cos(x) ** 2,
+               ],
+           ]
+       )
        if isinstance(shots, list):
            assert all(np.allclose(r, expected, atol=atol, rtol=0) for r in res)
        else:
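For completeness, the expected Jacobian in the multi-parameter test follows from ⟨Z₀⟩ = cos(x) and, since CNOT maps Z₁ to Z₀Z₁, Var(Z₁) = 1 − cos²(x)·cos²(y), with y already rescaled by the prefactor. A short numerical sketch of that derivation (illustrative, plain NumPy):

```python
import numpy as np

x, y0, p = 0.543, -0.654, 2.0
y = p * y0  # the RY angle after classical rescaling

# Var(Z1) = 1 - cos(x)**2 * cos(y)**2, since CNOT maps Z1 -> Z0 Z1
def var_z1(x_, y_):
    return 1 - np.cos(x_) ** 2 * np.cos(y_) ** 2

eps = 1e-6
# Finite-difference dVar/dx against the entry used in the test
fd_x = (var_z1(x + eps, y) - var_z1(x - eps, y)) / (2 * eps)
assert np.isclose(fd_x, 2 * np.cos(x) * np.sin(x) * np.cos(y) ** 2)

# Finite-difference dVar/dy0; the chain rule brings in the prefactor p
fd_y = (var_z1(x, p * (y0 + eps)) - var_z1(x, p * (y0 - eps))) / (2 * eps)
assert np.isclose(fd_y, 2 * p * np.cos(y) * np.sin(y) * np.cos(x) ** 2)
```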