Skip to content

Commit

Permalink
apply suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
antalszava committed Aug 4, 2021
1 parent c45226e commit 386eb77
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions pennylane/optimize/gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Gradient descent optimizer"""

from pennylane._grad import grad as get_gradient
from pennylane.math import get_trainable_indices
from pennylane.utils import _flatten, unflatten
from pennylane.numpy import ndarray, tensor

Expand Down Expand Up @@ -129,11 +130,7 @@ def compute_grad(objective_fn, args, kwargs, grad_fn=None):
grad = g(*args, **kwargs)
forward = getattr(g, "forward", None)

num_trainable_args = 0
for arg in args:
if getattr(arg, "requires_grad", True):
num_trainable_args += 1

num_trainable_args = len(get_trainable_indices(args))
if num_trainable_args == 1:
grad = (grad,)

Expand Down

0 comments on commit 386eb77

Please sign in to comment.