Expose Gaussian algorithms #3145

Merged 2 commits on Oct 18, 2022
135 changes: 14 additions & 121 deletions pyro/distributions/hmm.py
@@ -15,6 +15,8 @@
gaussian_tensordot,
matrix_and_mvn_to_gaussian,
mvn_to_gaussian,
+    sequential_gaussian_filter_sample,
+    sequential_gaussian_tensordot,
)
from pyro.ops.indexing import Vindex
from pyro.ops.special import safe_log
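
With this change the two parallel-scan routines live in the public `pyro.ops.gaussian` module instead of being private to hmm.py, so downstream code can import them directly. A minimal sketch of the resulting import (assuming this PR is merged):

```python
# Sketch: importing the newly exposed helpers (assumes this PR is merged).
from pyro.ops.gaussian import (
    sequential_gaussian_filter_sample,
    sequential_gaussian_tensordot,
)
```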
@@ -159,115 +161,6 @@ def _sequential_index(samples):
return samples.squeeze(-3)[..., :duration, :]


-def _sequential_gaussian_tensordot(gaussian):
-    """
-    Integrates a Gaussian ``x`` whose rightmost batch dimension is time, computes::
-
-        x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]
-    """
-    assert isinstance(gaussian, Gaussian)
-    assert gaussian.dim() % 2 == 0, "dim is not even"
-    batch_shape = gaussian.batch_shape[:-1]
-    state_dim = gaussian.dim() // 2
-    while gaussian.batch_shape[-1] > 1:
-        time = gaussian.batch_shape[-1]
-        even_time = time // 2 * 2
-        even_part = gaussian[..., :even_time]
-        x_y = even_part.reshape(batch_shape + (even_time // 2, 2))
-        x, y = x_y[..., 0], x_y[..., 1]
-        contracted = gaussian_tensordot(x, y, state_dim)
-        if time > even_time:
-            contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1)
-        gaussian = contracted
-    return gaussian[..., 0]
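
The loop above contracts adjacent pairs along the time axis, halving it each round, so T factors are combined in O(log T) sequential steps. For intuition, here is a minimal sketch of the same doubling scheme with ordinary matrix products standing in for `gaussian_tensordot()`; the names `sequential_matmul` and `mats` are hypothetical:

```python
import torch

def sequential_matmul(mats: torch.Tensor) -> torch.Tensor:
    """Computes mats[0] @ mats[1] @ ... @ mats[T-1] in O(log T) rounds,
    mirroring the doubling loop in _sequential_gaussian_tensordot()."""
    while mats.size(0) > 1:
        time = mats.size(0)
        even_time = time // 2 * 2
        x = mats[0:even_time:2]  # factors at even time steps
        y = mats[1:even_time:2]  # factors at odd time steps
        contracted = x @ y       # contract adjacent pairs in one batched matmul
        if time > even_time:     # odd length: carry the last factor forward
            contracted = torch.cat([contracted, mats[-1:]], dim=0)
        mats = contracted
    return mats[0]

# Agrees with the left-to-right product because matmul is associative.
mats = torch.randn(7, 3, 3)
expected = torch.linalg.multi_dot(list(mats))
assert torch.allclose(sequential_matmul(mats), expected, rtol=1e-4, atol=1e-4)
```

The same associativity argument justifies the Gaussian version: `gaussian_tensordot()` composes two Gaussian factors into one, so any pairing order yields the same product.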


-def _is_subshape(x, y):
-    return broadcast_shape(x, y) == y


-def _sequential_gaussian_filter_sample(init, trans, sample_shape):
-    """
-    Draws a reparameterized sample from a Markov product of Gaussians via
-    parallel-scan forward-filter backward-sample.
-    """
-    assert isinstance(init, Gaussian)
-    assert isinstance(trans, Gaussian)
-    assert trans.dim() == 2 * init.dim()
-    assert _is_subshape(trans.batch_shape[:-1], init.batch_shape)
-    state_dim = trans.dim() // 2
-    device = trans.precision.device
-    perm = torch.cat(
-        [
-            torch.arange(1 * state_dim, 2 * state_dim, device=device),
-            torch.arange(0 * state_dim, 1 * state_dim, device=device),
-            torch.arange(2 * state_dim, 3 * state_dim, device=device),
-        ]
-    )
-
-    # Forward filter, similar to _sequential_gaussian_tensordot().
-    tape = []
-    shape = trans.batch_shape[:-1]  # Note trans may be unbroadcasted.
-    gaussian = trans
-    while gaussian.batch_shape[-1] > 1:
-        time = gaussian.batch_shape[-1]
-        even_time = time // 2 * 2
-        even_part = gaussian[..., :even_time]
-        x_y = even_part.reshape(shape + (even_time // 2, 2))
-        x, y = x_y[..., 0], x_y[..., 1]
-        x = x.event_pad(right=state_dim)
-        y = y.event_pad(left=state_dim)
-        joint = (x + y).event_permute(perm)
-        tape.append(joint)
-        contracted = joint.marginalize(left=state_dim)
-        if time > even_time:
-            contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1)
-        gaussian = contracted
-    gaussian = gaussian[..., 0] + init.event_pad(right=state_dim)
-
-    # Backward sample.
-    shape = sample_shape + init.batch_shape
-    result = gaussian.rsample(sample_shape).reshape(shape + (2, state_dim))
-    for joint in reversed(tape):
-        # The following comments demonstrate two example computations, one
-        # EVEN, one ODD. Ignoring sample_shape and batch_shape, let each zn be
-        # a single sampled event of shape (state_dim,).
-        if joint.batch_shape[-1] == result.size(-2) - 1:  # EVEN case.
-            # Suppose e.g. result = [z0, z2, z4]
-            cond = result.repeat_interleave(2, dim=-2)  # [z0, z0, z2, z2, z4, z4]
-            cond = cond[..., 1:-1, :]  # [z0, z2, z2, z4]
-            cond = cond.reshape(shape + (-1, 2 * state_dim))  # [z0z2, z2z4]
-            sample = joint.condition(cond).rsample()  # [z1, z3]
-            sample = torch.nn.functional.pad(sample, (0, 0, 0, 1))  # [z1, z3, 0]
-            result = torch.stack(
-                [
-                    result,  # [z0, z2, z4]
-                    sample,  # [z1, z3, 0]
-                ],
-                dim=-2,
-            )  # [[z0, z1], [z2, z3], [z4, 0]]
-            result = result.reshape(shape + (-1, state_dim))  # [z0, z1, z2, z3, z4, 0]
-            result = result[..., :-1, :]  # [z0, z1, z2, z3, z4]
-        else:  # ODD case.
-            assert joint.batch_shape[-1] == result.size(-2) - 2
-            # Suppose e.g. result = [z0, z2, z3]
-            cond = result[..., :-1, :].repeat_interleave(2, dim=-2)  # [z0, z0, z2, z2]
-            cond = cond[..., 1:-1, :]  # [z0, z2]
-            cond = cond.reshape(shape + (-1, 2 * state_dim))  # [z0z2]
-            sample = joint.condition(cond).rsample()  # [z1]
-            sample = torch.cat([sample, result[..., -1:, :]], dim=-2)  # [z1, z3]
-            result = torch.stack(
-                [
-                    result[..., :-1, :],  # [z0, z2]
-                    sample,  # [z1, z3]
-                ],
-                dim=-2,
-            )  # [[z0, z1], [z2, z3]]
-            result = result.reshape(shape + (-1, state_dim))  # [z0, z1, z2, z3]
-
-    return result[..., 1:, :]  # [z1, z2, z3, ...]
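
The subtle part of the backward pass is the tensor bookkeeping that re-interleaves newly sampled odd-index states with the retained even-index states. A standalone sketch of the EVEN-case reshaping on toy tensors, with a constant standing in for `joint.condition(cond).rsample()`:

```python
import torch

state_dim = 2
# Pretend these are the retained samples [z0, z2, z4], each of shape (state_dim,).
result = torch.stack([torch.full((state_dim,), v) for v in (0.0, 2.0, 4.0)])

cond = result.repeat_interleave(2, dim=-2)              # [z0, z0, z2, z2, z4, z4]
cond = cond[..., 1:-1, :]                               # [z0, z2, z2, z4]
cond = cond.reshape(-1, 2 * state_dim)                  # [z0z2, z2z4]: pairs to condition on
sample = torch.full((2, state_dim), -1.0)               # stand-in for the conditioned rsample [z1, z3]
sample = torch.nn.functional.pad(sample, (0, 0, 0, 1))  # [z1, z3, 0]
result = torch.stack([result, sample], dim=-2)          # [[z0, z1], [z2, z3], [z4, 0]]
result = result.reshape(-1, state_dim)[:-1]             # [z0, z1, z2, z3, z4]
assert result.shape == (5, state_dim)
```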


def _sequential_gamma_gaussian_tensordot(gamma_gaussian):
"""
Integrates a GammaGaussian ``x`` whose rightmost batch dimension is time, computes::
Expand Down Expand Up @@ -657,9 +550,9 @@ def expand(self, batch_shape, _instance=None):
new._obs = self._obs
new._trans = self._trans

-        # To save computation in _sequential_gaussian_tensordot(), we expand
+        # To save computation in sequential_gaussian_tensordot(), we expand
         # only _init, which is applied only after
-        # _sequential_gaussian_tensordot().
+        # sequential_gaussian_tensordot().
batch_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape))
new._init = self._init.expand(batch_shape)

@@ -679,7 +572,7 @@ def log_prob(self, value):
)

# Eliminate time dimension.
-        result = _sequential_gaussian_tensordot(result.expand(result.batch_shape))
+        result = sequential_gaussian_tensordot(result.expand(result.batch_shape))

# Combine initial factor.
result = gaussian_tensordot(self._init, result, dims=self.hidden_dim)
@@ -695,7 +588,7 @@ def rsample(self, sample_shape=torch.Size()):
left=self.hidden_dim
)
trans = trans.expand(trans.batch_shape[:-1] + (self.duration,))
-        z = _sequential_gaussian_filter_sample(self._init, trans, sample_shape)
+        z = sequential_gaussian_filter_sample(self._init, trans, sample_shape)
x = self._obs.left_condition(z).rsample()
return x
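
In user code these paths are reached through `pyro.distributions.GaussianHMM`; a minimal usage sketch (shapes chosen arbitrarily, assuming the standard constructor arguments):

```python
import torch
import pyro.distributions as dist

hidden_dim, obs_dim, duration = 2, 3, 10
hmm = dist.GaussianHMM(
    initial_dist=dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim)),
    transition_matrix=0.9 * torch.eye(hidden_dim),
    transition_dist=dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim)),
    observation_matrix=torch.randn(hidden_dim, obs_dim),
    observation_dist=dist.MultivariateNormal(torch.zeros(obs_dim), torch.eye(obs_dim)),
    duration=duration,
)

x = hmm.rsample()       # exercises sequential_gaussian_filter_sample()
logp = hmm.log_prob(x)  # exercises sequential_gaussian_tensordot()
```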

@@ -705,7 +598,7 @@ def rsample_posterior(self, value, sample_shape=torch.Size()):
"""
trans = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)
trans = trans.expand(trans.batch_shape)
-        z = _sequential_gaussian_filter_sample(self._init, trans, sample_shape)
+        z = sequential_gaussian_filter_sample(self._init, trans, sample_shape)
return z

def filter(self, value):
@@ -726,7 +619,7 @@ def filter(self, value):
logp = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)

# Eliminate time dimension.
-        logp = _sequential_gaussian_tensordot(logp.expand(logp.batch_shape))
+        logp = sequential_gaussian_tensordot(logp.expand(logp.batch_shape))

# Combine initial factor.
logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)
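
Continuing the `GaussianHMM` sketch above, `filter()` runs this same time elimination and returns the posterior over the final hidden state:

```python
# Sketch, continuing the hmm example above.
data = hmm.rsample()
final_post = hmm.filter(data)  # posterior over the last hidden state
assert isinstance(final_post, dist.MultivariateNormal)
```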
@@ -780,7 +673,7 @@ def conjugate_update(self, other):
logp = new._trans + new._obs.marginalize(right=new.obs_dim).event_pad(
left=new.hidden_dim
)
-        logp = _sequential_gaussian_tensordot(logp.expand(logp.batch_shape))
+        logp = sequential_gaussian_tensordot(logp.expand(logp.batch_shape))
logp = gaussian_tensordot(new._init, logp, dims=new.hidden_dim)
log_normalizer = logp.event_logsumexp()
new._init = new._init - log_normalizer
@@ -970,8 +863,8 @@ def expand(self, batch_shape, _instance=None):
new.hidden_dim = self.hidden_dim
new.obs_dim = self.obs_dim
# We only need to expand one of the inputs, since batch_shape is determined
-        # by broadcasting all three. To save computation in _sequential_gaussian_tensordot(),
-        # we expand only _init, which is applied only after _sequential_gaussian_tensordot().
+        # by broadcasting all three. To save computation in sequential_gaussian_tensordot(),
+        # we expand only _init, which is applied only after sequential_gaussian_tensordot().
new._init = self._init.expand(batch_shape)
new._trans = self._trans
new._obs = self._obs
@@ -1380,8 +1273,8 @@ def expand(self, batch_shape, _instance=None):
new.hidden_dim = self.hidden_dim
new.obs_dim = self.obs_dim
# We only need to expand one of the inputs, since batch_shape is determined
-        # by broadcasting all three. To save computation in _sequential_gaussian_tensordot(),
-        # we expand only _init, which is applied only after _sequential_gaussian_tensordot().
+        # by broadcasting all three. To save computation in sequential_gaussian_tensordot(),
+        # we expand only _init, which is applied only after sequential_gaussian_tensordot().
new._init = self._init.expand(batch_shape)
new._trans = self._trans
new._obs = self._obs
@@ -1411,7 +1304,7 @@ def log_prob(self, value):
logp = Gaussian.cat([logp_oh.expand(batch_shape), logp_h.expand(batch_shape)])

# Eliminate time dimension.
-        logp = _sequential_gaussian_tensordot(logp)
+        logp = sequential_gaussian_tensordot(logp)

# Combine initial factor.
logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)