From d0cb1cccd84302814cbfa0d9e1e1a727fb4e839b Mon Sep 17 00:00:00 2001 From: frankmao666 Date: Wed, 28 Apr 2021 13:52:53 -0400 Subject: [PATCH] ignore grad when parameter not used, has a value of None --- pyro/infer/csis.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyro/infer/csis.py b/pyro/infer/csis.py index 1d392c2a2c..06db33a844 100644 --- a/pyro/infer/csis.py +++ b/pyro/infer/csis.py @@ -113,6 +113,8 @@ def loss_and_grads(self, grads, batch, *args, **kwargs): for site in particle_param_capture.trace.nodes.values()) guide_grads = torch.autograd.grad(particle_loss, guide_params, allow_unused=True) for guide_grad, guide_param in zip(guide_grads, guide_params): + if guide_grad is None: + continue guide_param.grad = guide_grad if guide_param.grad is None else guide_param.grad + guide_grad loss += torch_item(particle_loss)