diff --git a/pyro/distributions/hmm.py b/pyro/distributions/hmm.py index 5524766414..341f4881e2 100644 --- a/pyro/distributions/hmm.py +++ b/pyro/distributions/hmm.py @@ -15,6 +15,8 @@ gaussian_tensordot, matrix_and_mvn_to_gaussian, mvn_to_gaussian, + sequential_gaussian_filter_sample, + sequential_gaussian_tensordot, ) from pyro.ops.indexing import Vindex from pyro.ops.special import safe_log @@ -159,115 +161,6 @@ def _sequential_index(samples): return samples.squeeze(-3)[..., :duration, :] -def _sequential_gaussian_tensordot(gaussian): - """ - Integrates a Gaussian ``x`` whose rightmost batch dimension is time, computes:: - - x[..., 0] @ x[..., 1] @ ... @ x[..., T-1] - """ - assert isinstance(gaussian, Gaussian) - assert gaussian.dim() % 2 == 0, "dim is not even" - batch_shape = gaussian.batch_shape[:-1] - state_dim = gaussian.dim() // 2 - while gaussian.batch_shape[-1] > 1: - time = gaussian.batch_shape[-1] - even_time = time // 2 * 2 - even_part = gaussian[..., :even_time] - x_y = even_part.reshape(batch_shape + (even_time // 2, 2)) - x, y = x_y[..., 0], x_y[..., 1] - contracted = gaussian_tensordot(x, y, state_dim) - if time > even_time: - contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1) - gaussian = contracted - return gaussian[..., 0] - - -def _is_subshape(x, y): - return broadcast_shape(x, y) == y - - -def _sequential_gaussian_filter_sample(init, trans, sample_shape): - """ - Draws a reparameterized sample from a Markov product of Gaussians via - parallel-scan forward-filter backward-sample. - """ - assert isinstance(init, Gaussian) - assert isinstance(trans, Gaussian) - assert trans.dim() == 2 * init.dim() - assert _is_subshape(trans.batch_shape[:-1], init.batch_shape) - state_dim = trans.dim() // 2 - device = trans.precision.device - perm = torch.cat( - [ - torch.arange(1 * state_dim, 2 * state_dim, device=device), - torch.arange(0 * state_dim, 1 * state_dim, device=device), - torch.arange(2 * state_dim, 3 * state_dim, device=device), - ] - ) - - # Forward filter, similar to _sequential_gaussian_tensordot(). - tape = [] - shape = trans.batch_shape[:-1] # Note trans may be unbroadcasted. - gaussian = trans - while gaussian.batch_shape[-1] > 1: - time = gaussian.batch_shape[-1] - even_time = time // 2 * 2 - even_part = gaussian[..., :even_time] - x_y = even_part.reshape(shape + (even_time // 2, 2)) - x, y = x_y[..., 0], x_y[..., 1] - x = x.event_pad(right=state_dim) - y = y.event_pad(left=state_dim) - joint = (x + y).event_permute(perm) - tape.append(joint) - contracted = joint.marginalize(left=state_dim) - if time > even_time: - contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1) - gaussian = contracted - gaussian = gaussian[..., 0] + init.event_pad(right=state_dim) - - # Backward sample. - shape = sample_shape + init.batch_shape - result = gaussian.rsample(sample_shape).reshape(shape + (2, state_dim)) - for joint in reversed(tape): - # The following comments demonstrate two example computations, one - # EVEN, one ODD. Ignoring sample_shape and batch_shape, let each zn be - # a single sampled event of shape (state_dim,). - if joint.batch_shape[-1] == result.size(-2) - 1: # EVEN case. - # Suppose e.g. 
result = [z0, z2, z4] - cond = result.repeat_interleave(2, dim=-2) # [z0, z0, z2, z2, z4, z4] - cond = cond[..., 1:-1, :] # [z0, z2, z2, z4] - cond = cond.reshape(shape + (-1, 2 * state_dim)) # [z0z2, z2z4] - sample = joint.condition(cond).rsample() # [z1, z3] - sample = torch.nn.functional.pad(sample, (0, 0, 0, 1)) # [z1, z3, 0] - result = torch.stack( - [ - result, # [z0, z2, z4] - sample, # [z1, z3, 0] - ], - dim=-2, - ) # [[z0, z1], [z2, z3], [z4, 0]] - result = result.reshape(shape + (-1, state_dim)) # [z0, z1, z2, z3, z4, 0] - result = result[..., :-1, :] # [z0, z1, z2, z3, z4] - else: # ODD case. - assert joint.batch_shape[-1] == result.size(-2) - 2 - # Suppose e.g. result = [z0, z2, z3] - cond = result[..., :-1, :].repeat_interleave(2, dim=-2) # [z0, z0, z2, z2] - cond = cond[..., 1:-1, :] # [z0, z2] - cond = cond.reshape(shape + (-1, 2 * state_dim)) # [z0z2] - sample = joint.condition(cond).rsample() # [z1] - sample = torch.cat([sample, result[..., -1:, :]], dim=-2) # [z1, z3] - result = torch.stack( - [ - result[..., :-1, :], # [z0, z2] - sample, # [z1, z3] - ], - dim=-2, - ) # [[z0, z1], [z2, z3]] - result = result.reshape(shape + (-1, state_dim)) # [z0, z1, z2, z3] - - return result[..., 1:, :] # [z1, z2, z3, ...] - - def _sequential_gamma_gaussian_tensordot(gamma_gaussian): """ Integrates a GammaGaussian ``x`` whose rightmost batch dimension is time, computes:: @@ -657,9 +550,9 @@ def expand(self, batch_shape, _instance=None): new._obs = self._obs new._trans = self._trans - # To save computation in _sequential_gaussian_tensordot(), we expand + # To save computation in sequential_gaussian_tensordot(), we expand # only _init, which is applied only after - # _sequential_gaussian_tensordot(). + # sequential_gaussian_tensordot(). batch_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape)) new._init = self._init.expand(batch_shape) @@ -679,7 +572,7 @@ def log_prob(self, value): ) # Eliminate time dimension. - result = _sequential_gaussian_tensordot(result.expand(result.batch_shape)) + result = sequential_gaussian_tensordot(result.expand(result.batch_shape)) # Combine initial factor. result = gaussian_tensordot(self._init, result, dims=self.hidden_dim) @@ -695,7 +588,7 @@ def rsample(self, sample_shape=torch.Size()): left=self.hidden_dim ) trans = trans.expand(trans.batch_shape[:-1] + (self.duration,)) - z = _sequential_gaussian_filter_sample(self._init, trans, sample_shape) + z = sequential_gaussian_filter_sample(self._init, trans, sample_shape) x = self._obs.left_condition(z).rsample() return x @@ -705,7 +598,7 @@ def rsample_posterior(self, value, sample_shape=torch.Size()): """ trans = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim) trans = trans.expand(trans.batch_shape) - z = _sequential_gaussian_filter_sample(self._init, trans, sample_shape) + z = sequential_gaussian_filter_sample(self._init, trans, sample_shape) return z def filter(self, value): @@ -726,7 +619,7 @@ def filter(self, value): logp = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim) # Eliminate time dimension. - logp = _sequential_gaussian_tensordot(logp.expand(logp.batch_shape)) + logp = sequential_gaussian_tensordot(logp.expand(logp.batch_shape)) # Combine initial factor. 
logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim) @@ -780,7 +673,7 @@ def conjugate_update(self, other): logp = new._trans + new._obs.marginalize(right=new.obs_dim).event_pad( left=new.hidden_dim ) - logp = _sequential_gaussian_tensordot(logp.expand(logp.batch_shape)) + logp = sequential_gaussian_tensordot(logp.expand(logp.batch_shape)) logp = gaussian_tensordot(new._init, logp, dims=new.hidden_dim) log_normalizer = logp.event_logsumexp() new._init = new._init - log_normalizer @@ -970,8 +863,8 @@ def expand(self, batch_shape, _instance=None): new.hidden_dim = self.hidden_dim new.obs_dim = self.obs_dim # We only need to expand one of the inputs, since batch_shape is determined - # by broadcasting all three. To save computation in _sequential_gaussian_tensordot(), - # we expand only _init, which is applied only after _sequential_gaussian_tensordot(). + # by broadcasting all three. To save computation in sequential_gaussian_tensordot(), + # we expand only _init, which is applied only after sequential_gaussian_tensordot(). new._init = self._init.expand(batch_shape) new._trans = self._trans new._obs = self._obs @@ -1380,8 +1273,8 @@ def expand(self, batch_shape, _instance=None): new.hidden_dim = self.hidden_dim new.obs_dim = self.obs_dim # We only need to expand one of the inputs, since batch_shape is determined - # by broadcasting all three. To save computation in _sequential_gaussian_tensordot(), - # we expand only _init, which is applied only after _sequential_gaussian_tensordot(). + # by broadcasting all three. To save computation in sequential_gaussian_tensordot(), + # we expand only _init, which is applied only after sequential_gaussian_tensordot(). new._init = self._init.expand(batch_shape) new._trans = self._trans new._obs = self._obs @@ -1411,7 +1304,7 @@ def log_prob(self, value): logp = Gaussian.cat([logp_oh.expand(batch_shape), logp_h.expand(batch_shape)]) # Eliminate time dimension. - logp = _sequential_gaussian_tensordot(logp) + logp = sequential_gaussian_tensordot(logp) # Combine initial factor. logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim) diff --git a/pyro/ops/gaussian.py b/pyro/ops/gaussian.py index 3712aa8af5..3305a2d422 100644 --- a/pyro/ops/gaussian.py +++ b/pyro/ops/gaussian.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +from typing import Tuple import torch from torch.distributions.utils import lazy_property @@ -28,7 +29,12 @@ class Gaussian: :param torch.Tensor precision: precision matrix of this gaussian. 
""" - def __init__(self, log_normalizer, info_vec, precision): + def __init__( + self, + log_normalizer: torch.Tensor, + info_vec: torch.Tensor, + precision: torch.Tensor, + ): # NB: using info_vec instead of mean to deal with rank-deficient problem assert info_vec.dim() >= 1 assert precision.dim() >= 2 @@ -48,21 +54,21 @@ def batch_shape(self): self.precision.shape[:-2], ) - def expand(self, batch_shape): + def expand(self, batch_shape) -> "Gaussian": n = self.dim() log_normalizer = self.log_normalizer.expand(batch_shape) info_vec = self.info_vec.expand(batch_shape + (n,)) precision = self.precision.expand(batch_shape + (n, n)) return Gaussian(log_normalizer, info_vec, precision) - def reshape(self, batch_shape): + def reshape(self, batch_shape) -> "Gaussian": n = self.dim() log_normalizer = self.log_normalizer.reshape(batch_shape) info_vec = self.info_vec.reshape(batch_shape + (n,)) precision = self.precision.reshape(batch_shape + (n, n)) return Gaussian(log_normalizer, info_vec, precision) - def __getitem__(self, index): + def __getitem__(self, index) -> "Gaussian": """ Index into the batch_shape of a Gaussian. """ @@ -73,7 +79,7 @@ def __getitem__(self, index): return Gaussian(log_normalizer, info_vec, precision) @staticmethod - def cat(parts, dim=0): + def cat(parts, dim=0) -> "Gaussian": """ Concatenate a list of Gaussians along a given batch dimension. """ @@ -85,7 +91,7 @@ def cat(parts, dim=0): ] return Gaussian(*args) - def event_pad(self, left=0, right=0): + def event_pad(self, left=0, right=0) -> "Gaussian": """ Pad along event dimension. """ @@ -95,7 +101,7 @@ def event_pad(self, left=0, right=0): precision = pad(self.precision, lr + lr) return Gaussian(log_normalizer, info_vec, precision) - def event_permute(self, perm): + def event_permute(self, perm) -> "Gaussian": """ Permute along event dimension. """ @@ -105,7 +111,7 @@ def event_permute(self, perm): precision = self.precision[..., perm][..., perm, :] return Gaussian(self.log_normalizer, info_vec, precision) - def __add__(self, other): + def __add__(self, other: "Gaussian") -> "Gaussian": """ Adds two Gaussians in log-density space. """ @@ -120,12 +126,12 @@ def __add__(self, other): return Gaussian(self.log_normalizer + other, self.info_vec, self.precision) raise ValueError("Unsupported type: {}".format(type(other))) - def __sub__(self, other): + def __sub__(self, other: "Gaussian") -> "Gaussian": if isinstance(other, (int, float, torch.Tensor)): return Gaussian(self.log_normalizer - other, self.info_vec, self.precision) raise ValueError("Unsupported type: {}".format(type(other))) - def log_density(self, value): + def log_density(self, value: torch.Tensor) -> torch.Tensor: """ Evaluate the log density of this Gaussian at a point value:: @@ -135,13 +141,14 @@ def log_density(self, value): """ if value.size(-1) == 0: batch_shape = broadcast_shape(value.shape[:-1], self.batch_shape) - return self.log_normalizer.expand(batch_shape) + result: torch.Tensor = self.log_normalizer.expand(batch_shape) + return result result = (-0.5) * matvecmul(self.precision, value) result = result + self.info_vec result = (value * result).sum(-1) return result + self.log_normalizer - def rsample(self, sample_shape=torch.Size()): + def rsample(self, sample_shape=torch.Size()) -> torch.Tensor: """ Reparameterized sampler. 
""" @@ -150,9 +157,10 @@ def rsample(self, sample_shape=torch.Size()): shape = sample_shape + self.batch_shape + (self.dim(), 1) noise = torch.randn(shape, dtype=loc.dtype, device=loc.device) noise = triangular_solve(noise, P_chol, upper=False, transpose=True).squeeze(-1) - return loc + noise + sample: torch.Tensor = loc + noise + return sample - def condition(self, value): + def condition(self, value: torch.Tensor) -> "Gaussian": """ Condition this Gaussian on a trailing subset of its state. This should satisfy:: @@ -189,7 +197,7 @@ def condition(self, value): ) return Gaussian(log_normalizer, info_vec, precision) - def left_condition(self, value): + def left_condition(self, value: torch.Tensor) -> "Gaussian": """ Condition this Gaussian on a leading subset of its state. This should satisfy:: @@ -217,7 +225,7 @@ def left_condition(self, value): ) return self.event_permute(perm).condition(value) - def marginalize(self, left=0, right=0): + def marginalize(self, left=0, right=0) -> "Gaussian": """ Marginalizing out variables on either side of the event dimension:: @@ -259,7 +267,7 @@ def marginalize(self, left=0, right=0): ) return Gaussian(log_normalizer, info_vec, precision) - def event_logsumexp(self): + def event_logsumexp(self) -> torch.Tensor: """ Integrates out all latent state (i.e. operating on event dimensions). """ @@ -269,12 +277,13 @@ def event_logsumexp(self): self.info_vec.unsqueeze(-1), chol_P, upper=False ).squeeze(-1) u_P_u = chol_P_u.pow(2).sum(-1) - return ( + log_Z: torch.Tensor = ( self.log_normalizer + 0.5 * n * math.log(2 * math.pi) + 0.5 * u_P_u - chol_P.diagonal(dim1=-2, dim2=-1).log().sum(-1) ) + return log_Z class AffineNormal: @@ -477,7 +486,7 @@ def matrix_and_mvn_to_gaussian(matrix, mvn): return result -def gaussian_tensordot(x, y, dims=0): +def gaussian_tensordot(x: Gaussian, y: Gaussian, dims: int = 0) -> Gaussian: """ Computes the integral over two gaussians: @@ -538,3 +547,125 @@ def gaussian_tensordot(x, y, dims=0): log_normalizer = log_normalizer + diff return Gaussian(log_normalizer, info_vec, precision) + + +def sequential_gaussian_tensordot(gaussian: Gaussian) -> Gaussian: + """ + Integrates a Gaussian ``x`` whose rightmost batch dimension is time, computes:: + + x[..., 0] @ x[..., 1] @ ... @ x[..., T-1] + + :param Gaussian gaussian: A batched Gaussian whose rightmost dimension is time. + :returns: A Markov product of the Gaussian along its time dimension. + :rtype: Gaussian + """ + assert isinstance(gaussian, Gaussian) + assert gaussian.dim() % 2 == 0, "dim is not even" + batch_shape = gaussian.batch_shape[:-1] + state_dim = gaussian.dim() // 2 + while gaussian.batch_shape[-1] > 1: + time = gaussian.batch_shape[-1] + even_time = time // 2 * 2 + even_part = gaussian[..., :even_time] + x_y = even_part.reshape(batch_shape + (even_time // 2, 2)) + x, y = x_y[..., 0], x_y[..., 1] + contracted = gaussian_tensordot(x, y, state_dim) + if time > even_time: + contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1) + gaussian = contracted + return gaussian[..., 0] + + +def _is_subshape(x, y): + return broadcast_shape(x, y) == y + + +def sequential_gaussian_filter_sample( + init: Gaussian, trans: Gaussian, sample_shape: Tuple[int, ...] = () +) -> torch.Tensor: + """ + Draws a reparameterized sample from a Markov product of Gaussians via + parallel-scan forward-filter backward-sample. + + :param Gaussian init: A Gaussian representing an initial state. 
+    :param Gaussian trans: A Gaussian representing a series of state transitions,
+        with time as the rightmost batch dimension.
+    :param tuple sample_shape: An optional batch shape of samples to draw.
+    :returns: A reparameterized sample.
+    :rtype: torch.Tensor
+    """
+    assert isinstance(init, Gaussian)
+    assert isinstance(trans, Gaussian)
+    assert trans.dim() == 2 * init.dim()
+    assert _is_subshape(trans.batch_shape[:-1], init.batch_shape)
+    state_dim = trans.dim() // 2
+    device = trans.precision.device
+    perm = torch.cat(
+        [
+            torch.arange(1 * state_dim, 2 * state_dim, device=device),
+            torch.arange(0 * state_dim, 1 * state_dim, device=device),
+            torch.arange(2 * state_dim, 3 * state_dim, device=device),
+        ]
+    )
+
+    # Forward filter, similar to sequential_gaussian_tensordot().
+    tape = []
+    shape = trans.batch_shape[:-1]  # Note trans may be unbroadcasted.
+    gaussian = trans
+    while gaussian.batch_shape[-1] > 1:
+        time = gaussian.batch_shape[-1]
+        even_time = time // 2 * 2
+        even_part = gaussian[..., :even_time]
+        x_y = even_part.reshape(shape + (even_time // 2, 2))
+        x, y = x_y[..., 0], x_y[..., 1]
+        x = x.event_pad(right=state_dim)
+        y = y.event_pad(left=state_dim)
+        joint = (x + y).event_permute(perm)
+        tape.append(joint)
+        contracted = joint.marginalize(left=state_dim)
+        if time > even_time:
+            contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1)
+        gaussian = contracted
+    gaussian = gaussian[..., 0] + init.event_pad(right=state_dim)
+
+    # Backward sample.
+    shape = sample_shape + init.batch_shape
+    result = gaussian.rsample(sample_shape).reshape(shape + (2, state_dim))
+    for joint in reversed(tape):
+        # The following comments demonstrate two example computations, one
+        # EVEN, one ODD. Ignoring sample_shape and batch_shape, let each zn be
+        # a single sampled event of shape (state_dim,).
+        if joint.batch_shape[-1] == result.size(-2) - 1:  # EVEN case.
+            # Suppose e.g. result = [z0, z2, z4]
+            cond = result.repeat_interleave(2, dim=-2)  # [z0, z0, z2, z2, z4, z4]
+            cond = cond[..., 1:-1, :]  # [z0, z2, z2, z4]
+            cond = cond.reshape(shape + (-1, 2 * state_dim))  # [z0z2, z2z4]
+            sample = joint.condition(cond).rsample()  # [z1, z3]
+            sample = torch.nn.functional.pad(sample, (0, 0, 0, 1))  # [z1, z3, 0]
+            result = torch.stack(
+                [
+                    result,  # [z0, z2, z4]
+                    sample,  # [z1, z3, 0]
+                ],
+                dim=-2,
+            )  # [[z0, z1], [z2, z3], [z4, 0]]
+            result = result.reshape(shape + (-1, state_dim))  # [z0, z1, z2, z3, z4, 0]
+            result = result[..., :-1, :]  # [z0, z1, z2, z3, z4]
+        else:  # ODD case.
+            assert joint.batch_shape[-1] == result.size(-2) - 2
+            # Suppose e.g. result = [z0, z2, z3]
+            cond = result[..., :-1, :].repeat_interleave(2, dim=-2)  # [z0, z0, z2, z2]
+            cond = cond[..., 1:-1, :]  # [z0, z2]
+            cond = cond.reshape(shape + (-1, 2 * state_dim))  # [z0z2]
+            sample = joint.condition(cond).rsample()  # [z1]
+            sample = torch.cat([sample, result[..., -1:, :]], dim=-2)  # [z1, z3]
+            result = torch.stack(
+                [
+                    result[..., :-1, :],  # [z0, z2]
+                    sample,  # [z1, z3]
+                ],
+                dim=-2,
+            )  # [[z0, z1], [z2, z3]]
+            result = result.reshape(shape + (-1, state_dim))  # [z0, z1, z2, z3]
+
+    return result[..., 1:, :]  # [z1, z2, z3, ...]
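Reviewer note: a minimal usage sketch of the two newly public functions. The sizes and the random linear-Gaussian construction below are hypothetical; it assumes only the existing `mvn_to_gaussian` and `matrix_and_mvn_to_gaussian` helpers from `pyro.ops.gaussian`.

    import torch
    from pyro.ops.gaussian import (
        matrix_and_mvn_to_gaussian,
        mvn_to_gaussian,
        sequential_gaussian_filter_sample,
        sequential_gaussian_tensordot,
    )

    state_dim, duration = 3, 10  # hypothetical sizes
    mvn = torch.distributions.MultivariateNormal(
        torch.zeros(state_dim), torch.eye(state_dim)
    )
    # trans represents p(z[t+1] | z[t]) for a random linear-Gaussian model;
    # it has dim == 2 * state_dim and batch_shape == (duration,).
    trans_matrix = torch.randn(duration, state_dim, state_dim) / state_dim**0.5
    trans = matrix_and_mvn_to_gaussian(trans_matrix, mvn)
    init = mvn_to_gaussian(mvn)  # dim == state_dim

    # Eliminate the time dimension in O(log(duration)) rounds of contractions.
    reduced = sequential_gaussian_tensordot(trans)
    assert reduced.batch_shape == () and reduced.dim() == 2 * state_dim

    # Draw reparameterized joint samples of all hidden states.
    z = sequential_gaussian_filter_sample(init, trans, sample_shape=(5,))
    assert z.shape == (5, duration, state_dim)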
diff --git a/tests/distributions/test_hmm.py b/tests/distributions/test_hmm.py index 1bd88a3d1f..01ea316855 100644 --- a/tests/distributions/test_hmm.py +++ b/tests/distributions/test_hmm.py @@ -13,8 +13,6 @@ import pyro.distributions as dist from pyro.distributions.hmm import ( _sequential_gamma_gaussian_tensordot, - _sequential_gaussian_filter_sample, - _sequential_gaussian_tensordot, _sequential_logmatmulexp, ) from pyro.distributions.util import broadcast_shape @@ -36,7 +34,7 @@ random_gamma, random_gamma_gaussian, ) -from tests.ops.gaussian import assert_close_gaussian, random_gaussian, random_mvn +from tests.ops.gaussian import random_mvn logger = logging.getLogger(__name__) @@ -93,35 +91,6 @@ def test_sequential_logmatmulexp(batch_shape, state_dim, num_steps): assert_close(actual, expected) -@pytest.mark.parametrize("num_steps", list(range(1, 20))) -@pytest.mark.parametrize("state_dim", [1, 2, 3]) -@pytest.mark.parametrize("batch_shape", [(), (5,), (2, 4)], ids=str) -def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps): - g = random_gaussian(batch_shape + (num_steps,), state_dim + state_dim) - actual = _sequential_gaussian_tensordot(g) - assert actual.dim() == g.dim() - assert actual.batch_shape == batch_shape - - # Check against hand computation. - expected = g[..., 0] - for t in range(1, num_steps): - expected = gaussian_tensordot(expected, g[..., t], state_dim) - assert_close_gaussian(actual, expected) - - -@pytest.mark.parametrize("num_steps", list(range(1, 20))) -@pytest.mark.parametrize("state_dim", [1, 2, 3]) -@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) -@pytest.mark.parametrize("sample_shape", [(), (4,), (3, 2)], ids=str) -def test_sequential_gaussian_filter_sample( - sample_shape, batch_shape, state_dim, num_steps -): - init = random_gaussian(batch_shape, state_dim) - trans = random_gaussian(batch_shape + (num_steps,), state_dim + state_dim) - sample = _sequential_gaussian_filter_sample(init, trans, sample_shape) - assert sample.shape == sample_shape + batch_shape + (num_steps, state_dim) - - @pytest.mark.parametrize("num_steps", list(range(1, 20))) @pytest.mark.parametrize("state_dim", [1, 2, 3]) @pytest.mark.parametrize("batch_shape", [(), (5,), (2, 4)], ids=str) diff --git a/tests/ops/test_gaussian.py b/tests/ops/test_gaussian.py index fa5924b128..6d4e043885 100644 --- a/tests/ops/test_gaussian.py +++ b/tests/ops/test_gaussian.py @@ -17,6 +17,8 @@ gaussian_tensordot, matrix_and_mvn_to_gaussian, mvn_to_gaussian, + sequential_gaussian_filter_sample, + sequential_gaussian_tensordot, ) from tests.common import assert_close from tests.ops.gaussian import assert_close_gaussian, random_gaussian, random_mvn @@ -488,3 +490,32 @@ def check_equal(actual, expected, atol=0.01, rtol=0): funsor.ops.mean, "particle" ) check_equal(fp_entropy.data, entropy) + + +@pytest.mark.parametrize("num_steps", list(range(1, 20))) +@pytest.mark.parametrize("state_dim", [1, 2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (5,), (2, 4)], ids=str) +def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps): + g = random_gaussian(batch_shape + (num_steps,), state_dim + state_dim) + actual = sequential_gaussian_tensordot(g) + assert actual.dim() == g.dim() + assert actual.batch_shape == batch_shape + + # Check against hand computation. 
+ expected = g[..., 0] + for t in range(1, num_steps): + expected = gaussian_tensordot(expected, g[..., t], state_dim) + assert_close_gaussian(actual, expected) + + +@pytest.mark.parametrize("num_steps", list(range(1, 20))) +@pytest.mark.parametrize("state_dim", [1, 2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) +@pytest.mark.parametrize("sample_shape", [(), (4,), (3, 2)], ids=str) +def test_sequential_gaussian_filter_sample( + sample_shape, batch_shape, state_dim, num_steps +): + init = random_gaussian(batch_shape, state_dim) + trans = random_gaussian(batch_shape + (num_steps,), state_dim + state_dim) + sample = sequential_gaussian_filter_sample(init, trans, sample_shape) + assert sample.shape == sample_shape + batch_shape + (num_steps, state_dim)
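For reference, the parallel-scan reduction can also be sanity-checked outside the test suite. This sketch uses a hypothetical `random_full_rank_gaussian` stand-in (not the `random_gaussian` helper from tests/ops/gaussian.py) and runs in float64 so tight tolerances are meaningful.

    import torch
    from pyro.ops.gaussian import (
        Gaussian,
        gaussian_tensordot,
        sequential_gaussian_tensordot,
    )

    torch.set_default_dtype(torch.float64)

    def random_full_rank_gaussian(batch_shape, dim):
        # Hypothetical stand-in: a random Gaussian in information form with
        # an (almost surely) positive definite precision matrix.
        x = torch.randn(batch_shape + (dim, 2 * dim))
        precision = x @ x.transpose(-2, -1) / (2 * dim)
        info_vec = torch.randn(batch_shape + (dim,))
        log_normalizer = torch.randn(batch_shape)
        return Gaussian(log_normalizer, info_vec, precision)

    state_dim, num_steps = 2, 7
    g = random_full_rank_gaussian((num_steps,), 2 * state_dim)

    # The tree reduction must agree with a naive left-to-right fold,
    # since gaussian_tensordot is associative.
    actual = sequential_gaussian_tensordot(g)
    expected = g[..., 0]
    for t in range(1, num_steps):
        expected = gaussian_tensordot(expected, g[..., t], state_dim)

    assert torch.allclose(actual.log_normalizer, expected.log_normalizer, atol=1e-6)
    assert torch.allclose(actual.info_vec, expected.info_vec, atol=1e-6)
    assert torch.allclose(actual.precision, expected.precision, atol=1e-6)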