diff --git a/pyro/infer/csis.py b/pyro/infer/csis.py index 1d392c2a2c..06db33a844 100644 --- a/pyro/infer/csis.py +++ b/pyro/infer/csis.py @@ -113,6 +113,8 @@ def loss_and_grads(self, grads, batch, *args, **kwargs): for site in particle_param_capture.trace.nodes.values()) guide_grads = torch.autograd.grad(particle_loss, guide_params, allow_unused=True) for guide_grad, guide_param in zip(guide_grads, guide_params): + if guide_grad is None: + continue guide_param.grad = guide_grad if guide_param.grad is None else guide_param.grad + guide_grad loss += torch_item(particle_loss)