Fix how trainable args are counted for gradients in `GradientDescentOptimizer` and `NesterovMomentumOptimizer` (#1495)

* count the trainable args, not simply the args

* changelog

* update nesterov opt too

* update nesterov opt too

* comment

* changelog update

* remove redundant enumerate

* format

* add another test for two trainable args

* extract trainable args

* enumerate trainable args

* Update tests/test_optimize.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* Update tests/test_optimize.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* apply suggestion

* add new test case

* test docstring

* Revert "apply suggestion"

This reverts commit 386eb77.

* format

* remove extra dimensionality from non-trainable

* create a fixture that returns a new optimizer object for each test case; the previous version left a state because the object lived on

* format test

* Update tests/test_optimize.py

Co-authored-by: Josh Izaac <josh146@gmail.com>
antalszava and josh146 authored Aug 6, 2021
1 parent c0cdff2 commit 53e1dc8
Showing 4 changed files with 123 additions and 11 deletions.
5 changes: 5 additions & 0 deletions .github/CHANGELOG.md
@@ -466,6 +466,11 @@
 
 <h3>Bug fixes</h3>
 
+* Fixes a bug in `GradientDescentOptimizer` and `NesterovMomentumOptimizer`
+  where a cost function with one trainable parameter and non-trainable
+  parameters raised an error.
+  [(#1495)](https://github.com/PennyLaneAI/pennylane/pull/1495)
+
 * Fixed an example in the documentation's
   [introduction to numpy gradients](https://pennylane.readthedocs.io/en/stable/introduction/interfaces/numpy.html), where
   the wires were a non-differentiable argument to the QNode.
7 changes: 6 additions & 1 deletion pennylane/optimize/gradient_descent.py
@@ -129,7 +129,12 @@ def compute_grad(objective_fn, args, kwargs, grad_fn=None):
         grad = g(*args, **kwargs)
         forward = getattr(g, "forward", None)
 
-        if len(args) == 1:
+        num_trainable_args = 0
+        for arg in args:
+            if getattr(arg, "requires_grad", True):
+                num_trainable_args += 1
+
+        if num_trainable_args == 1:
             grad = (grad,)
 
         return grad, forward
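For intuition, here is a hedged, standalone sketch of the counting rule introduced above; the `count_trainable` helper is illustrative only and not part of the diff. An argument is treated as trainable unless it explicitly carries `requires_grad=False`, and `compute_grad` wraps the gradient in a one-element tuple only when exactly one argument is trainable.

```python
# Illustrative sketch only -- `count_trainable` is a hypothetical helper, not
# code from pennylane/optimize/gradient_descent.py.
from pennylane import numpy as np

def count_trainable(*args):
    # An argument counts as trainable unless it explicitly sets requires_grad=False,
    # mirroring the getattr(arg, "requires_grad", True) check in the diff.
    return sum(1 for arg in args if getattr(arg, "requires_grad", True))

x = np.tensor(0.0, requires_grad=True)           # circuit parameter to optimize
y = np.tensor(0.5, requires_grad=True)           # a second trainable parameter
target = np.tensor(0.7781, requires_grad=False)  # fixed target value

print(count_trainable(x, target))     # 1 -> the single gradient gets wrapped as (grad,)
print(count_trainable(x, y, target))  # 2 -> the gradient is already a tuple, one entry per trainable arg
```

The old check, `len(args) == 1`, miscounted as soon as a non-trainable argument such as a target value was passed alongside the single trainable parameter.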
9 changes: 7 additions & 2 deletions pennylane/optimize/nesterov_momentum.py
@@ -59,8 +59,13 @@ def compute_grad(self, objective_fn, args, kwargs, grad_fn=None):
         """
         shifted_args = list(args)
 
+        trainable_args = []
+        for arg in args:
+            if getattr(arg, "requires_grad", True):
+                trainable_args.append(arg)
+
         if self.accumulation:
-            for index, arg in enumerate(args):
+            for index, arg in enumerate(trainable_args):
                 if self.accumulation[index]:
                     x_flat = _flatten(arg)
                     acc = _flatten(self.accumulation[index])
@@ -82,7 +87,7 @@ def compute_grad(self, objective_fn, args, kwargs, grad_fn=None):
         grad = g(*shifted_args, **kwargs)
         forward = getattr(g, "forward", None)
 
-        if len(args) == 1:
+        if len(trainable_args) == 1:
             grad = (grad,)
 
         return grad, forward
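The switch to `enumerate(trainable_args)` matters because the momentum accumulation is indexed per trainable argument. A hedged sketch of the alignment issue, with illustrative values rather than library code:

```python
# Illustrative sketch only: the accumulation list holds one entry per *trainable*
# argument, so enumerating over all arguments can pair a non-trainable argument
# with an accumulator, or index past the end of the list.
from pennylane import numpy as np

target = np.tensor(0.7781, requires_grad=False)  # non-trainable argument
x = np.tensor(0.1, requires_grad=True)           # trainable argument
args = (target, x)

accumulation = [np.tensor(0.05)]  # a single accumulator, belonging to `x`
momentum = 0.9

trainable_args = [arg for arg in args if getattr(arg, "requires_grad", True)]

for index, arg in enumerate(trainable_args):
    shifted = arg - momentum * accumulation[index]  # simplified form of the flattened update above

# With enumerate(args), index 0 would have shifted the non-trainable `target`,
# and index 1 would have failed on accumulation[1], which does not exist here.
```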
113 changes: 105 additions & 8 deletions tests/test_optimize.py
@@ -747,16 +747,40 @@ def reset(opt):
         opt.reset()
 
 
+@pytest.fixture
+def opt(opt_name):
+    if opt_name == "gd":
+        return GradientDescentOptimizer(stepsize)
+
+    if opt_name == "nest":
+        return NesterovMomentumOptimizer(stepsize, momentum=gamma)
+
+    if opt_name == "moment":
+        return MomentumOptimizer(stepsize, momentum=gamma)
+
+    if opt_name == "ada":
+        return AdagradOptimizer(stepsize)
+
+    if opt_name == "rms":
+        return RMSPropOptimizer(stepsize, decay=gamma)
+
+    if opt_name == "adam":
+        return AdamOptimizer(stepsize, beta1=gamma, beta2=delta)
+
+    if opt_name == "roto":
+        return RotosolveOptimizer()
+
+
 @pytest.mark.parametrize(
-    "opt, opt_name",
+    "opt_name",
     [
-        (GradientDescentOptimizer(stepsize), "gd"),
-        (MomentumOptimizer(stepsize, momentum=gamma), "moment"),
-        (NesterovMomentumOptimizer(stepsize, momentum=gamma), "nest"),
-        (AdagradOptimizer(stepsize), "ada"),
-        (RMSPropOptimizer(stepsize, decay=gamma), "rms"),
-        (AdamOptimizer(stepsize, beta1=gamma, beta2=delta), "adam"),
-        (RotosolveOptimizer(), "roto"),
+        "gd",
+        "moment",
+        "nest",
+        "ada",
+        "rms",
+        "adam",
+        "roto",
     ],
 )
 class TestOverOpts:
@@ -877,3 +901,76 @@ def func2(args):
         assert x_seperate == pytest.approx(args_new[0], abs=tol)
         assert y_seperate == pytest.approx(args_new[1], abs=tol)
         assert z_seperate == pytest.approx(args_new[2], abs=tol)
+
+    def test_one_trainable_one_non_trainable(self, opt, opt_name, tol):
+        """Tests that a cost function that takes one trainable and one
+        non-trainable parameter executes well."""
+        dev = qml.device("default.qubit", wires=2)
+
+        @qml.qnode(dev)
+        def circuit(x):
+            qml.RX(x, wires=0)
+            return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
+
+        def cost(x, target):
+            return (circuit(x) - target) ** 2
+
+        ev = np.tensor(0.7781, requires_grad=False)
+        x = np.tensor(0.0, requires_grad=True)
+
+        original_ev = ev
+
+        (x, ev), cost = opt.step_and_cost(cost, x, ev)
+
+        # check that the argument to RX doesn't change, as the X rotation doesn't influence <Z>
+        assert x == 0
+        assert ev == original_ev
+
+    def test_one_non_trainable_one_trainable(self, opt, opt_name, tol):
+        """Tests that a cost function that takes one non-trainable and one
+        trainable parameter executes well."""
+        dev = qml.device("default.qubit", wires=2)
+
+        @qml.qnode(dev)
+        def circuit(x):
+            qml.RX(x, wires=0)
+            return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
+
+        def cost(target, x):  # Note: the order of the arguments has been swapped
+            return (circuit(x) - target) ** 2
+
+        ev = np.tensor(0.7781, requires_grad=False)
+        x = np.tensor(0.0, requires_grad=True)
+
+        original_ev = ev
+
+        (ev, x), cost = opt.step_and_cost(cost, ev, x)
+        # check that the argument to RX doesn't change, as the X rotation doesn't influence <Z>
+        assert x == 0
+        assert ev == original_ev
+
+    def test_two_trainable_args(self, opt, opt_name, tol):
+        """Tests that a cost function that takes at least two trainable
+        arguments executes well."""
+        dev = qml.device("default.qubit", wires=2)
+
+        @qml.qnode(dev)
+        def circuit(x, y):
+            qml.RX(x, wires=0)
+            qml.RX(y, wires=0)
+            return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
+
+        def cost(x, y, target):
+            return (circuit(x, y) - target) ** 2
+
+        ev = np.tensor(0.7781, requires_grad=False)
+        x = np.tensor(0.0, requires_grad=True)
+        y = np.tensor(0.0, requires_grad=True)
+
+        original_ev = ev
+
+        (x, y, ev), cost = opt.step_and_cost(cost, x, y, ev)
+
+        # check that the argument to RX doesn't change, as the X rotation doesn't influence <Z>
+        assert x == 0
+        assert ev == original_ev
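On the test side, the parametrization now passes optimizer names and the new `opt` fixture builds a fresh optimizer object for each test case, since several optimizers are stateful. A hedged sketch of the kind of state leak the previous shared instances allowed; the cost function and step size below are illustrative only:

```python
# Illustrative sketch only: stateful optimizers keep internal accumulators
# between calls to step(), so a single module-level instance shared across
# parametrized tests carries state from one test case into the next.
import pennylane as qml
from pennylane import numpy as np

def cost(x):
    return (x - 1.0) ** 2

x = np.tensor(0.0, requires_grad=True)

shared = qml.AdamOptimizer(stepsize=0.1)
shared.step(cost, x)  # the optimizer's internal moment estimates are now populated
# ...any later test reusing `shared` would start from this leaked state

fresh = qml.AdamOptimizer(stepsize=0.1)  # fixture-style: a brand-new object per test case
```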