Skip to content

Commit

Permalink
Bug fix to AffineAutoregressive (#2504)
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanwebb authored May 25, 2020
1 parent aa40beb commit 21e3aee
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
2 changes: 0 additions & 2 deletions .gitattributes

This file was deleted.

2 changes: 1 addition & 1 deletion pyro/distributions/transforms/affine_autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ def _inverse_stable(self, y):
mean, logit_scale = self.arn(torch.stack(x, dim=-1))
inverse_scale = 1 + torch.exp(-logit_scale[..., idx] - self.sigmoid_bias)
x[idx] = inverse_scale * y[..., idx] + (1 - inverse_scale) * mean[..., idx]
self._cached_log_scale = inverse_scale

self._cached_log_scale = self.logsigmoid(logit_scale + self.sigmoid_bias)
x = torch.stack(x, dim=-1)
return x

Expand Down
11 changes: 7 additions & 4 deletions tests/distributions/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,19 @@ def nonzero(x):
assert lower_sum == float(0.0)

def _test_inverse(self, shape, transform):
# Test g^{-1}(g(x)) = x
# NOTE: Calling _call and _inverse directly bypasses caching
base_dist = dist.Normal(torch.zeros(shape), torch.ones(shape))

x_true = base_dist.sample(torch.Size([10]))
y = transform._call(x_true)

# Cache is empty, hence must be calculating inverse afresh
J_1 = transform.log_abs_det_jacobian(x_true, y)
x_calculated = transform._inverse(y)

J_2 = transform.log_abs_det_jacobian(x_true, y)
assert (x_true - x_calculated).abs().max().item() < self.delta

# Test that Jacobian after inverse op is same as after forward
assert (J_1 - J_2).abs().max().item() < self.delta

def _test_shape(self, base_shape, transform):
base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape))
sample = dist.TransformedDistribution(base_dist, [transform]).sample()
Expand Down

0 comments on commit 21e3aee

Please sign in to comment.