Skip to content

Commit

Permalink
[FIX] Use materialize_grad=True in GGNVPs
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed May 3, 2024
1 parent 1ebfb40 commit d2ff6fe
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
1 change: 1 addition & 0 deletions backpack/hessianfree/lop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def L_op(ys, xs, ws, retain_graph=True, detach=True):
create_graph=True,
retain_graph=retain_graph,
allow_unused=True,
materialize_grads=True,
)
if detach:
return tuple(j.detach() for j in vJ)
Expand Down
9 changes: 8 additions & 1 deletion backpack/hessianfree/rop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,17 @@ def R_op(ys, xs, vs, retain_graph=True, detach=True):
create_graph=True,
retain_graph=retain_graph,
allow_unused=True,
materialize_grads=True,
)

re = torch.autograd.grad(
gs, ws, grad_outputs=vs, create_graph=True, retain_graph=True, allow_unused=True
gs,
ws,
grad_outputs=vs,
create_graph=True,
retain_graph=True,
allow_unused=True,
materialize_grads=True,
)

if detach:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ setup_requires =
setuptools_scm
# Dependencies of the project (semicolon/line-separated):
install_requires =
torch >= 1.9.0
torch >= 2.2.0
torchvision >= 0.7.0
einops >= 0.3.0, < 1.0.0
unfoldNd >= 0.2.0, < 1.0.0
Expand Down

0 comments on commit d2ff6fe

Please sign in to comment.