From d2ff6fe65dea0cf93da497c9392188ba5d53a3c6 Mon Sep 17 00:00:00 2001
From: Felix Dangel
Date: Fri, 3 May 2024 17:18:37 -0400
Subject: [PATCH] [FIX] Use `materialize_grads=True` in GGNVPs

---
 backpack/hessianfree/lop.py | 1 +
 backpack/hessianfree/rop.py | 9 ++++++++-
 setup.cfg                   | 2 +-
 3 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/backpack/hessianfree/lop.py b/backpack/hessianfree/lop.py
index 60c0e044e..8d15d8615 100644
--- a/backpack/hessianfree/lop.py
+++ b/backpack/hessianfree/lop.py
@@ -13,6 +13,7 @@ def L_op(ys, xs, ws, retain_graph=True, detach=True):
         create_graph=True,
         retain_graph=retain_graph,
         allow_unused=True,
+        materialize_grads=True,
     )
     if detach:
         return tuple(j.detach() for j in vJ)
diff --git a/backpack/hessianfree/rop.py b/backpack/hessianfree/rop.py
index 007b9e7ef..e50733cad 100644
--- a/backpack/hessianfree/rop.py
+++ b/backpack/hessianfree/rop.py
@@ -18,10 +18,17 @@ def R_op(ys, xs, vs, retain_graph=True, detach=True):
         create_graph=True,
         retain_graph=retain_graph,
         allow_unused=True,
+        materialize_grads=True,
     )
 
     re = torch.autograd.grad(
-        gs, ws, grad_outputs=vs, create_graph=True, retain_graph=True, allow_unused=True
+        gs,
+        ws,
+        grad_outputs=vs,
+        create_graph=True,
+        retain_graph=True,
+        allow_unused=True,
+        materialize_grads=True,
     )
 
     if detach:
diff --git a/setup.cfg b/setup.cfg
index 8be33bc68..277759089 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -34,7 +34,7 @@ setup_requires = setuptools_scm
 # Dependencies of the project (semicolon/line-separated):
 install_requires =
-    torch >= 1.9.0
+    torch >= 2.2.0
     torchvision >= 0.7.0
     einops >= 0.3.0, < 1.0.0
     unfoldNd >= 0.2.0, < 1.0.0
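
Note (illustration, not part of the patch): the sketch below shows the behavior that `materialize_grads=True` changes. The toy tensors `a` and `b` are hypothetical; the point is that when an input does not influence the output, `allow_unused=True` alone makes `torch.autograd.grad` return `None` for that input, while `materialize_grads=True` (available in recent PyTorch releases, which is presumably why the patch also bumps the `torch` requirement) returns a zero tensor instead, so GGN-vector-product code that chains two `autograd.grad` calls, as `R_op` does, need not special-case `None` entries.

import torch

# Hypothetical toy setup: `b` does not influence `y`, so autograd treats it as unused.
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
y = (a**2).sum()

# allow_unused=True alone: the gradient w.r.t. the unused input is None.
grads = torch.autograd.grad(y, (a, b), allow_unused=True)
print(grads[1])  # None

# materialize_grads=True additionally turns unused gradients into zero tensors,
# so a second autograd.grad call can consume them uniformly.
grads = torch.autograd.grad(y, (a, b), allow_unused=True, materialize_grads=True)
print(grads[1])  # tensor([0., 0., 0.])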