Fix qml.grad so that the returned gradient always matches the cost function return type if only a single argument is differentiated (#1067)

* Bugfix for qml.grad

* Bugfix for qml.grad

* add tests

* fix tests

* fix

* Apply suggestions from code review

Co-authored-by: Chase Roberts <chase@xanadu.ai>

Co-authored-by: Chase Roberts <chase@xanadu.ai>
Co-authored-by: antalszava <antalszava@gmail.com>
3 people authored Feb 9, 2021
1 parent 812e9d4 commit cfdb6f8
Showing 10 changed files with 147 additions and 17 deletions.
5 changes: 5 additions & 0 deletions .github/CHANGELOG.md
@@ -48,6 +48,11 @@

<h3>Bug fixes</h3>

* If only one argument to the function `qml.grad` has the `requires_grad` attribute
set to True, then the returned gradient will be a NumPy array, rather than a
tuple of length 1.
[(#1067)](https://github.com/PennyLaneAI/pennylane/pull/1067)

<h3>Documentation</h3>

<h3>Contributors</h3>
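To make the changelog entry concrete, here is a minimal sketch of the fixed behaviour (the device, circuit, and weights are illustrative, not taken from the changed files):

import pennylane as qml
from pennylane import numpy as np

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(weights):
    qml.RX(weights[0], wires=0)
    qml.RY(weights[1], wires=0)
    return qml.expval(qml.PauliZ(0))

weights = np.array([0.1, 0.2], requires_grad=True)

# Only ``weights`` is trainable, so the gradient is now a NumPy array with
# the same shape as ``weights``, rather than a tuple of length 1.
grad = qml.grad(circuit)(weights)
assert grad.shape == weights.shape
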
3 changes: 3 additions & 0 deletions pennylane/_grad.py
@@ -85,6 +85,9 @@ def _get_grad_fn(self, args):
if getattr(arg, "requires_grad", True):
argnum.append(idx)

if len(argnum) == 1:
argnum = argnum[0]

return self._grad_with_forward(
self._fun,
argnum=argnum,
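The change above collapses a single-element argnum list into a bare integer before it reaches the wrapped autograd machinery. A sketch of the convention this relies on, shown here with plain autograd as an assumption about the backend rather than PennyLane's exact internals:

import autograd
import autograd.numpy as anp

def cost(x, y):
    return anp.sin(x) * y

# An integer argnum yields the bare gradient; a list always yields a tuple,
# even when it contains a single index.
print(autograd.grad(cost, 0)(0.5, 2.0))    # 2 * cos(0.5), a plain scalar
print(autograd.grad(cost, [0])(0.5, 2.0))  # (2 * cos(0.5),), a length-1 tuple
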
3 changes: 3 additions & 0 deletions pennylane/optimize/gradient_descent.py
@@ -127,6 +127,9 @@ def compute_grad(objective_fn, args, kwargs, grad_fn=None):
grad = g(*args, **kwargs)
forward = getattr(g, "forward", None)

if len(args) == 1:
grad = (grad,)

return grad, forward

def apply_grad(self, grad, args):
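Since qml.grad now returns a bare array when a single argument is trainable, compute_grad wraps it back into a length-1 tuple so that downstream code can keep pairing one gradient entry with each argument; the same re-wrapping appears in the Nesterov optimizer below. A minimal, hypothetical illustration of why that pairing matters (toy_apply_grad and step_size are invented for this sketch and are not the optimizer's real internals):

from pennylane import numpy as np

def toy_apply_grad(grad, args, step_size=0.1):
    # one gradient entry is consumed per trainable argument
    return [arg - step_size * g for arg, g in zip(args, grad)]

args = (np.array([0.1, 0.2], requires_grad=True),)
grad = np.array([0.5, -0.3])  # bare array: only one trainable argument

if len(args) == 1:
    grad = (grad,)            # restore the one-entry-per-argument structure

print(toy_apply_grad(grad, args))  # prints the single updated argument, [0.05, 0.23]
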
3 changes: 3 additions & 0 deletions pennylane/optimize/nesterov_momentum.py
@@ -82,4 +82,7 @@ def compute_grad(self, objective_fn, args, kwargs, grad_fn=None):
grad = g(*shifted_args, **kwargs)
forward = getattr(g, "forward", None)

if len(args) == 1:
grad = (grad,)

return grad, forward
4 changes: 2 additions & 2 deletions tests/devices/test_default_qubit_autograd.py
@@ -352,7 +352,7 @@ def circuit(weights):
qml.init.strong_ent_layers_normal(n_wires=2, n_layers=2), requires_grad=True
)

- grad = qml.grad(circuit)(weights)[0]
+ grad = qml.grad(circuit)(weights)
assert grad.shape == weights.shape

def test_qnode_collection_integration(self):
@@ -374,7 +374,7 @@ def ansatz(weights, **kwargs):
def cost(weights):
return np.sum(qnodes(weights))

- grad = qml.grad(cost)(weights)[0]
+ grad = qml.grad(cost)(weights)
assert grad.shape == weights.shape

class TestOps:
11 changes: 4 additions & 7 deletions tests/interfaces/test_autograd.py
@@ -541,8 +541,7 @@ def circuit(weights, data1, data2):

# we do not check for correctness, just that the output
# is the correct shape
- assert len(res) == 1
- assert res[0].shape == weights.shape
+ assert res.shape == weights.shape

# check that the first arg was marked as non-differentiable
assert circuit.get_trainable_args() == {0}
@@ -587,8 +586,7 @@ def circuit(data1, weights, data2):

# we do not check for correctness, just that the output
# is the correct shape
- assert len(res) == 1
- assert res[0].shape == weights.shape
+ assert res.shape == weights.shape

# check that the second arg was marked as non-differentiable
assert circuit.get_trainable_args() == {1}
@@ -633,8 +631,7 @@ def circuit(data1, data2, weights):

# we do not check for correctness, just that the output
# is the correct shape
- assert len(res) == 1
- assert res[0].shape == weights.shape
+ assert res.shape == weights.shape

# check that the last arg was marked as non-differentiable
assert circuit.get_trainable_args() == {2}
@@ -748,7 +745,7 @@ def cost(weights):
grad_fn = qml.grad(cost)
res = grad_fn(weights)

- assert len(res[0]) == 2
+ assert len(res) == 2

def test_gradient_value(self, tol):
"""Test that the returned gradient value for a qubit QNode is correct,
2 changes: 1 addition & 1 deletion tests/math/test_autograd_box.py
@@ -175,7 +175,7 @@ def test_autodifferentiation():
cost_fn = lambda a: (qml.math.TensorBox(a).T() ** 2).unbox()[0, 1]
grad_fn = qml.grad(cost_fn)

- res = grad_fn(x)[0]
+ res = grad_fn(x)
expected = np.array([[0.0, 0.0, 0.0], [8.0, 0.0, 0.0]])
assert np.all(res == expected)

2 changes: 1 addition & 1 deletion tests/math/test_functions.py
@@ -975,7 +975,7 @@ def cost(weights):
assert isinstance(res, np.ndarray)
assert fn.allclose(res, onp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.3136]]))

- grad = qml.grad(lambda weights: cost(weights)[1, 2])([x, y])[0]
+ grad = qml.grad(lambda weights: cost(weights)[1, 2])([x, y])
assert fn.allclose(grad[0], onp.array([[0, 0, 0], [0, 0, 1.]]))
assert fn.allclose(grad[1], 2 * y)

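The test above also illustrates that when the single differentiable argument is itself a list of tensors, the returned gradient now mirrors the list's structure directly instead of arriving inside a length-1 tuple. A small sketch of that usage (the cost function here is illustrative):

import pennylane as qml
from pennylane import numpy as np

x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], requires_grad=True)
y = np.array(2.0, requires_grad=True)

cost = lambda weights: np.sum(weights[0]) + weights[1] ** 2

grad = qml.grad(cost)([x, y])
# grad is indexed like the input list: grad[0] matches x, grad[1] matches y
assert grad[0].shape == x.shape
assert np.allclose(grad[1], 2 * y)
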
129 changes: 124 additions & 5 deletions tests/tape/interfaces/test_qnode_autograd.py
@@ -118,9 +118,8 @@ def circuit(a):

# gradients should work
grad = qml.grad(circuit)(a)
- assert len(grad) == 1
- assert isinstance(grad[0], np.ndarray)
- assert grad[0].shape == tuple()
+ assert isinstance(grad, float)
+ assert grad.shape == tuple()

def test_interface_swap(self, dev_name, diff_method, tol):
"""Test that the autograd interface can be applied to a QNode
@@ -152,10 +151,9 @@ def circuit(a):

res = circuit(a)
grad = qml.grad(circuit)(a)
- assert len(grad) == 1

assert np.allclose(res, res_tf, atol=tol, rtol=0)
- assert np.allclose(grad[0], grad_tf, atol=tol, rtol=0)
+ assert np.allclose(grad, grad_tf, atol=tol, rtol=0)

def test_jacobian(self, dev_name, diff_method, mocker, tol):
"""Test jacobian calculation"""
@@ -594,6 +592,127 @@ def circuit():
assert res.shape == (2, 10)
assert isinstance(res, np.ndarray)

def test_gradient_non_differentiable_exception(self, dev_name, diff_method):
"""Test that an exception is raised if non-differentiable data is
differentiated"""
dev = qml.device(dev_name, wires=2)

@qml.qnode(dev, interface="autograd", diff_method=diff_method)
def circuit(data1):
qml.templates.AmplitudeEmbedding(data1, wires=[0, 1])
return qml.expval(qml.PauliZ(0))

grad_fn = qml.grad(circuit, argnum=0)
data1 = np.array([0, 1, 1, 0], requires_grad=False) / np.sqrt(2)

with pytest.raises(qml.numpy.NonDifferentiableError, match="is non-differentiable"):
grad_fn(data1)

def test_chained_qnodes(self, dev_name, diff_method):
"""Test that the gradient of chained QNodes works without error"""
dev = qml.device(dev_name, wires=2)

@qml.qnode(dev, interface="autograd", diff_method=diff_method)
def circuit1(weights):
qml.templates.StronglyEntanglingLayers(weights, wires=[0, 1])
return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))

@qml.qnode(dev, interface="autograd", diff_method=diff_method)
def circuit2(data, weights):
qml.templates.AngleEmbedding(data, wires=[0, 1])
qml.templates.StronglyEntanglingLayers(weights, wires=[0, 1])
return qml.expval(qml.PauliX(0))

def cost(weights):
w1, w2 = weights
c1 = circuit1(w1)
c2 = circuit2(c1, w2)
return np.sum(c2) ** 2

w1 = qml.init.strong_ent_layers_normal(n_wires=2, n_layers=3)
w2 = qml.init.strong_ent_layers_normal(n_wires=2, n_layers=4)

weights = [w1, w2]

grad_fn = qml.grad(cost)
res = grad_fn(weights)

assert len(res) == 2

def test_chained_gradient_value(self, dev_name, diff_method, tol):
"""Test that the returned gradient value for two chained qubit QNodes
is correct."""
dev1 = qml.device(dev_name, wires=3)

@qml.qnode(dev1, diff_method=diff_method)
def circuit1(a, b, c):
qml.RX(a, wires=0)
qml.RX(b, wires=1)
qml.RX(c, wires=2)
qml.CNOT(wires=[0, 1])
qml.CNOT(wires=[1, 2])
return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliY(2))

dev2 = qml.device("default.qubit", wires=2)

@qml.qnode(dev2, diff_method=diff_method)
def circuit2(data, weights):
qml.RX(data[0], wires=0)
qml.RX(data[1], wires=1)
qml.CNOT(wires=[0, 1])
qml.RZ(weights[0], wires=0)
qml.RZ(weights[1], wires=1)
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliX(0) @ qml.PauliY(1))

def cost(a, b, c, weights):
return circuit2(circuit1(a, b, c), weights)

grad_fn = qml.grad(cost)

# Set the first parameter of circuit1 as non-differentiable.
a = np.array(0.4, requires_grad=False)

# The remaining free parameters are all differentiable.
b = 0.5
c = 0.1
weights = np.array([0.2, 0.3])

res = grad_fn(a, b, c, weights)

# Output should have shape [dcost/db, dcost/dc, dcost/dw],
# where b,c are scalars, and w is a vector of length 2.
assert len(res) == 3
assert res[0].shape == tuple() # scalar
assert res[1].shape == tuple() # scalar
assert res[2].shape == (2,) # vector

cacbsc = np.cos(a)*np.cos(b)*np.sin(c)

expected = np.array([
# analytic expression for dcost/db
-np.cos(a)*np.sin(b)*np.sin(c)*np.cos(cacbsc)*np.sin(weights[0])*np.sin(np.cos(a)),
# analytic expression for dcost/dc
np.cos(a)*np.cos(b)*np.cos(c)*np.cos(cacbsc)*np.sin(weights[0])*np.sin(np.cos(a)),
# analytic expression for dcost/dw[0]
np.sin(cacbsc)*np.cos(weights[0])*np.sin(np.cos(a)),
# analytic expression for dcost/dw[1]
0
])

# np.hstack 'flattens' the ragged gradient array allowing it
# to be compared with the expected result
assert np.allclose(np.hstack(res), expected, atol=tol, rtol=0)

if diff_method != "backprop":
# Check that the gradient was computed
# for all parameters in circuit2
assert circuit2.qtape.trainable_params == {0, 1, 2, 3}

# Check that the parameter-shift rule was not applied
# to the first parameter of circuit1.
assert circuit1.qtape.trainable_params == {1, 2}


def qtransform(qnode, a, framework=np):
"""Transforms every RY(y) gate in a circuit to RX(-a*cos(y))"""
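A short usage sketch tying the new tests together (the circuit here is illustrative and assumes default.qubit): arguments marked with requires_grad=False are excluded from differentiation, and with a single trainable argument left the gradient comes back as one array of matching shape.

import pennylane as qml
from pennylane import numpy as np

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(data, weights):
    qml.RX(data, wires=0)
    qml.RY(weights[0], wires=0)
    qml.RZ(weights[1], wires=0)
    return qml.expval(qml.PauliZ(0))

data = np.array(0.1, requires_grad=False)      # excluded from the gradient
weights = np.array([0.2, 0.3], requires_grad=True)

grad = qml.grad(circuit)(data, weights)
assert grad.shape == weights.shape             # a single array, not a tuple
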
2 changes: 1 addition & 1 deletion tests/test_classical_gradients.py
@@ -306,7 +306,7 @@ def test_no_argnum_grad(self, mocker, tol):
res = grad_fn(x, y)
expected = np.array([np.cos(x) * np.cos(y) + y ** 2])
assert np.allclose(res, expected, atol=tol, rtol=0)
- assert spy.call_args_list[0][1]["argnum"] == [0]
+ assert spy.call_args_list[0][1]["argnum"] == 0


class TestJacobian:
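Finally, the explicit argnum keyword follows the same convention that the assertion above checks internally: a single integer index returns that argument's gradient directly, while a list of indices returns a tuple. A hedged sketch (the cost function is illustrative):

import pennylane as qml
from pennylane import numpy as np

def cost(x, y):
    return np.sin(x) * np.cos(y) + y ** 2

x, y = np.array(0.5), np.array(0.3)

dcost_dx = qml.grad(cost, argnum=0)(x, y)        # gradient w.r.t. x only
dcost_dxy = qml.grad(cost, argnum=[0, 1])(x, y)  # tuple: (dcost/dx, dcost/dy)
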
