diff --git a/docs/source/interpretations.rst b/docs/source/interpretations.rst
index 009afb84..61be0e5d 100644
--- a/docs/source/interpretations.rst
+++ b/docs/source/interpretations.rst
@@ -23,6 +23,13 @@ Monte Carlo
     :show-inheritance:
     :member-order: bysource
 
+Preconditioning
+---------------
+.. automodule:: funsor.precondition
+    :members:
+    :show-inheritance:
+    :member-order: bysource
+
 Approximations
 --------------
 .. automodule:: funsor.approximations
diff --git a/funsor/__init__.py b/funsor/__init__.py
index dfbbf86a..9ac1ee5d 100644
--- a/funsor/__init__.py
+++ b/funsor/__init__.py
@@ -46,6 +46,7 @@
     joint,
     montecarlo,
     ops,
+    precondition,
     recipes,
     sum_product,
     terms,
@@ -102,6 +103,7 @@
     "montecarlo",
     "of_shape",
     "ops",
+    "precondition",
     "pretty",
     "quote",
     "reals",
diff --git a/funsor/adjoint.py b/funsor/adjoint.py
index a573e448..4b779380 100644
--- a/funsor/adjoint.py
+++ b/funsor/adjoint.py
@@ -67,10 +67,7 @@ def __enter__(self):
         self._old_interpretation = interpreter.get_interpretation()
         return super().__enter__()
 
     def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=frozenset()):
-        # TODO Replace this with root + Constant(...) after #548 merges.
-        root_vars = root.input_vars | batch_vars
-
         zero = to_funsor(ops.UNITS[sum_op])
         one = to_funsor(ops.UNITS[bin_op])
         adjoint_values = defaultdict(lambda: zero)
@@ -118,7 +115,7 @@ def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=frozenset())
             in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs)
             for v, adjv in in_adjs:
                 # Marginalize out message variables that don't appear in recipients.
-                agg_vars = adjv.input_vars - v.input_vars - root_vars
+                agg_vars = adjv.input_vars - v.input_vars - root.input_vars - batch_vars
                 assert "particle" not in {var.name for var in agg_vars}  # DEBUG FIXME
                 old_value = adjoint_values[v]
                 adjoint_values[v] = sum_op(old_value, adjv.reduce(sum_op, agg_vars))
diff --git a/funsor/cnf.py b/funsor/cnf.py
index 2ce978e0..8d2bf65e 100644
--- a/funsor/cnf.py
+++ b/funsor/cnf.py
@@ -106,6 +106,11 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
         sampled_vars = sampled_vars.intersection(self.inputs)
         if not sampled_vars:
             return self
+        for term in self.terms:
+            if isinstance(term, Delta):
+                sampled_vars -= term.fresh
+        if not sampled_vars:
+            return self
 
         if self.red_op in (ops.null, ops.logaddexp):
             if rng_key is not None and get_backend() == "jax":
@@ -116,8 +121,8 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
                 rng_keys = [None] * len(self.terms)
 
             if self.bin_op in (ops.null, ops.logaddexp):
-                # Design choice: we sample over logaddexp reductions, but leave logaddexp
-                # binary choices symbolic.
+                # Design choice: we sample over logaddexp reductions, but leave
+                # logaddexp binary choices symbolic.
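+                # E.g. a reduction like (f + g).reduce(ops.logaddexp, "x")
+                # will sample the reduced variable "x", whereas a lazy binary
+                # mixture ops.logaddexp(f, g) is left unsampled.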
terms = [ term._sample( sampled_vars.intersection(term.inputs), sample_inputs, rng_key @@ -132,11 +137,15 @@ def _sample(self, sampled_vars, sample_inputs, rng_key): greedy_vars = sampled_vars.intersection(term.inputs) if greedy_vars: break + assert greedy_vars greedy_terms, terms = [], [] for term in self.terms: - ( - terms if greedy_vars.isdisjoint(term.inputs) else greedy_terms - ).append(term) + if greedy_vars.isdisjoint(term.inputs): + terms.append(term) + elif isinstance(term, Delta) and greedy_vars.isdisjoint(term.fresh): + terms.append(term) + else: + greedy_terms.append(term) if len(greedy_terms) == 1: term = greedy_terms[0] terms.append(term._sample(greedy_vars, sample_inputs, rng_keys[0])) @@ -392,7 +401,7 @@ def _(fn): # Normalizing Contractions ########################################## -ORDERING = {Delta: 1, Number: 2, Tensor: 3, Gaussian: 4} +ORDERING = {Delta: 1, Number: 2, Tensor: 3, Gaussian: 4, Unary[ops.NegOp, Gaussian]: 5} GROUND_TERMS = tuple(ORDERING) diff --git a/funsor/domains.py b/funsor/domains.py index c5add799..ac9bab76 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -494,6 +494,22 @@ def _find_domain_stack(op, parts): return output +@find_domain.register(ops.CatOp) +def _find_domain_cat(op, parts): + dim = op.defaults["axis"] + if dim >= 0: + event_dims = {len(x.shape) for x in parts} + assert len(event_dims) == 1, "undefined" + dim = dim - next(iter(event_dims)) + assert dim < 0 + shape = broadcast_shape(*(x.shape[:dim] for x in parts)) + shape += (sum(x.shape[dim] for x in parts),) + if dim < -1: + shape += broadcast_shape(*(x.shape[dim + 1 :] for x in parts)) + output = Array[parts[0].dtype, shape] + return output + + @find_domain.register(ops.EinsumOp) def _find_domain_einsum(op, operands): equation = op.defaults["equation"] diff --git a/funsor/gaussian.py b/funsor/gaussian.py index 50d5dfc2..2e04a417 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -6,9 +6,6 @@ from contextlib import contextmanager from functools import reduce -import numpy as np - -import funsor import funsor.ops as ops from funsor.affine import affine_inputs, extract_affine, is_affine from funsor.delta import Delta @@ -17,7 +14,6 @@ from funsor.ops import AddOp, SubOp from funsor.tensor import Tensor, align_tensor, align_tensors from funsor.terms import ( - Align, Binary, Funsor, FunsorMeta, @@ -28,7 +24,7 @@ eager, reflect, ) -from funsor.util import broadcast_shape, get_backend, get_tracing_state, lazy_property +from funsor.util import broadcast_shape, get_tracing_state, lazy_property def _log_det_tri(x): @@ -854,13 +850,12 @@ def eager_reduce(self, op, reduced_vars): reduced_ints = reduced_vars - real_vars if not reduced_reals: return None # defer to default implementation + if reduced_reals == real_vars: + return self.log_normalizer.reduce(ops.logaddexp, reduced_ints) inputs = OrderedDict( (k, d) for k, d in self.inputs.items() if k not in reduced_reals ) - if reduced_reals == real_vars: - return self.log_normalizer.reduce(ops.logaddexp, reduced_ints) - int_inputs = OrderedDict( (k, v) for k, v in inputs.items() if v.dtype != "real" ) @@ -898,7 +893,6 @@ def eager_reduce(self, op, reduced_vars): b, a = _split_real_inputs(self.inputs, reduced_vars, self.white_vec) prec_sqrt_a = self.prec_sqrt[..., a, :] prec_sqrt_b = self.prec_sqrt[..., b, :] - dim_a = prec_sqrt_a.shape[-2] dim_b = prec_sqrt_b.shape[-2] if self.rank < dim_b: raise ValueError( @@ -906,27 +900,9 @@ def eager_reduce(self, op, reduced_vars): "Consider adding a prior." 
) precision_chol_b = ops.cholesky(_mmt(prec_sqrt_b)) # assume full rank - b_log_normalizer = Tensor( - dim_b * math.log(2 * math.pi) / 2 - _log_det_tri(precision_chol_b), - int_inputs, + result = self._marginalize_after_split( + inputs, int_inputs, prec_sqrt_b, prec_sqrt_a, precision_chol_b ) - result = b_log_normalizer - if self.rank > dim_b: - proj_b = _mtm(ops.triangular_solve(prec_sqrt_b, precision_chol_b)) - prec_sqrt = prec_sqrt_a - prec_sqrt_a @ proj_b - white_vec = self.white_vec - _vm(self.white_vec, proj_b) - result += Gaussian(white_vec, prec_sqrt, inputs) - else: # The Gaussian over xa is zero. - # TODO switch from an empty Gaussian to a Constant once this works: - # from .constant import Constant - # const_inputs = OrderedDict( - # (k, v) for k, v in inputs.items() if k not in result.inputs - # ) - # result = Constant(const_inputs, result) - batch_shape = self.white_vec.shape[:-1] - white_vec = ops.new_zeros(self.white_vec, batch_shape + (0,)) - prec_sqrt = ops.new_zeros(self.white_vec, batch_shape + (dim_a, 0)) - result += Gaussian(white_vec, prec_sqrt, inputs) return result.reduce(ops.logaddexp, reduced_ints) elif op is ops.add: @@ -982,50 +958,150 @@ def _sample(self, sampled_vars, sample_inputs, rng_key): sample_inputs = OrderedDict( (k, d) for k, d in sample_inputs.items() if k not in self.inputs ) - sample_shape = tuple(int(d.dtype) for d in sample_inputs.values()) - int_inputs = OrderedDict( - (k, d) for k, d in self.inputs.items() if d.dtype != "real" - ) - real_inputs = OrderedDict( - (k, d) for k, d in self.inputs.items() if d.dtype == "real" - ) - inputs = sample_inputs.copy() - inputs.update(int_inputs) + int_inputs = OrderedDict() + sampled_real_inputs = OrderedDict() + remaining_real_inputs = OrderedDict() + for k, d in self.inputs.items(): + if d.dtype != "real": + int_inputs[k] = d + elif k in sampled_vars: + sampled_real_inputs[k] = d + else: + remaining_real_inputs[k] = d + if self.rank < sum(d.num_elements for d in sampled_real_inputs.values()): + raise ValueError( + f"Too little information to sample over {set(sampled_vars)}. " + "Consider adding a prior." + ) - assert self.is_full_rank - if sampled_vars == frozenset(real_inputs): - # Call _compress_rank() to triangularize. + if not remaining_real_inputs: # Sample all variables. + # Triangularize via _compress_rank(). white_vec, prec_sqrt, _ = _compress_rank( self.white_vec, self.prec_sqrt, assume_full_rank=True ) - shape = sample_shape + white_vec.shape - backend = get_backend() - if backend != "numpy": - from importlib import import_module - - dist = import_module( - funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend] - ) - sample_args = (shape,) if rng_key is None else (rng_key, shape) - white_noise = dist.Normal.dist_class(0, 1).sample(*sample_args) - else: - white_noise = np.random.randn(*shape) + # Jointly sample. + # This section may involve either Funsors or backend arrays. 
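+            # In this parameterization -log density is, up to a constant,
+            #   ||x @ prec_sqrt - white_vec||^2 / 2,
+            # so a reparameterized sample is drawn by sampling white noise
+            # eps ~ Normal(0, I) and solving the triangular system
+            # prec_sqrt^T x = white_vec + eps below.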
+ dim = prec_sqrt.shape[-1] + white_noise = _sample_white_noise( + sample_inputs, int_inputs, dim, self.white_vec, rng_key + ) + if isinstance(white_noise, Funsor): + white_vec = Tensor(white_vec, int_inputs) + prec_sqrt = Tensor(prec_sqrt, int_inputs) sample = ops.triangular_solve( (white_noise + white_vec)[..., None], prec_sqrt, transpose=True )[..., 0] - offsets, _ = _compute_offsets(real_inputs) - results = [] - for key, domain in real_inputs.items(): - data = sample[..., offsets[key] : offsets[key] + domain.num_elements] - data = data.reshape(shape[:-1] + domain.shape) - point = Tensor(data, inputs) - assert point.output == domain - results.append(Delta(key, point)) - results.append(self.log_normalizer) - return reduce(ops.add, results) - raise NotImplementedError("TODO implement partial sampling of real variables") + # Compute the remaining Tensor. + remaining = self.log_normalizer + + else: # Sample only a subset of real variables. + # Split into sampled variables a and remaining variables b. + a, b = _split_real_inputs(self.inputs, sampled_vars, self.white_vec) + prec_sqrt_a = self.prec_sqrt[..., a, :] + prec_sqrt_b = self.prec_sqrt[..., b, :] + dim_a = prec_sqrt_a.shape[-2] + + # Compute white_vec of a lazily conditioned on b's variables. + # This requires Funsors rather than backend arrays. + flat = ops.cat( + [ + Variable(k, d).reshape((d.num_elements,)) + for k, d in remaining_real_inputs.items() + ] + ) + white_vec_a = ( + Tensor(self.white_vec, int_inputs) + - (flat[None] @ Tensor(prec_sqrt_b, int_inputs))[0] + ) + + # Triangularize. + precision_chol_a = Tensor(ops.cholesky(_mmt(prec_sqrt_a)), int_inputs) + white_vec_a = ops.triangular_solve( + Tensor(prec_sqrt_a, int_inputs) @ white_vec_a[..., None], + precision_chol_a, + )[..., 0] + + # Jointly sample. + white_noise = _sample_white_noise( + sample_inputs, int_inputs, dim_a, self.white_vec, rng_key + ) + if not isinstance(white_noise, Funsor): + inputs = sample_inputs.copy() + inputs.update(int_inputs) + white_noise = Tensor(white_noise, inputs) + sample = ops.triangular_solve( + (white_noise + white_vec_a)[..., None], precision_chol_a, transpose=True + )[..., 0] + + # Compute the remaining Gaussian, equivalent to + # self.reduce(ops.logaddexp, sampled_vars), but avoiding duplicate work. + inputs = int_inputs.copy() + inputs.update(remaining_real_inputs) + remaining = self._marginalize_after_split( + inputs, int_inputs, prec_sqrt_a, prec_sqrt_b, precision_chol_a.data + ) + + # Extract shaped components of the flat concatenated sample. + results = [remaining] + offsets, _ = _compute_offsets(sampled_real_inputs) + for key, domain in sampled_real_inputs.items(): + point = sample[..., offsets[key] : offsets[key] + domain.num_elements] + point = point.reshape(point.shape[:-1] + domain.shape) + if not isinstance(point, Funsor): # I.e. when eagerly sampling. + inputs = sample_inputs.copy() + inputs.update(int_inputs) + point = Tensor(point, inputs) + assert point.output == domain + results.append(Delta(key, point)) + + return reduce(ops.add, results) + + def _marginalize_after_split( + self, inputs, int_inputs, prec_sqrt_a, prec_sqrt_b, precision_chol_a + ): + """ + Helper used in partial reduction and partial sampling. + This marginalizes over a and returns a shifted Gaussian over b. 
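+        Concretely, integrating out the block ``a`` contributes the normalizer
+        ``dim_a * log(2 * pi) / 2 - log det(precision_chol_a)``, and the
+        remaining Gaussian over ``b`` has its ``prec_sqrt`` and ``white_vec``
+        projected onto the orthogonal complement of the row space of
+        ``prec_sqrt_a`` (or is empty when ``a`` exhausts the rank).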
+ """ + dim_a = prec_sqrt_a.shape[-2] + dim_b = prec_sqrt_b.shape[-2] + result = Tensor( + dim_a * math.log(2 * math.pi) / 2 - _log_det_tri(precision_chol_a), + int_inputs, + ) + if self.rank > dim_a: + proj_a = _mtm(ops.triangular_solve(prec_sqrt_a, precision_chol_a)) + prec_sqrt = prec_sqrt_b - prec_sqrt_b @ proj_a + white_vec = self.white_vec - _vm(self.white_vec, proj_a) + result += Gaussian(white_vec, prec_sqrt, inputs) + else: # The Gaussian over xa is zero. + # TODO switch from an empty Gaussian to a Constant once this works: + # from .constant import Constant + # const_inputs = OrderedDict( + # (k, v) for k, v in inputs.items() if k not in result.inputs + # ) + # result = Constant(const_inputs, result) + batch_shape = self.white_vec.shape[:-1] + white_vec = ops.new_zeros(self.white_vec, batch_shape + (0,)) + prec_sqrt = ops.new_zeros(self.white_vec, batch_shape + (dim_b, 0)) + result += Gaussian(white_vec, prec_sqrt, inputs) + return result + + +def _sample_white_noise(sample_inputs, int_inputs, dim, prototype, rng_key): + if [v.dtype for v in sample_inputs.values()] == ["real"]: + # Lazily compute a sample as a function of white noise. + k, d = next(iter(sample_inputs.items())) + return Variable(k, d)[tuple(int_inputs)] + + # Eagerly draw noise. + shape = tuple(d.size for d in sample_inputs.values() if d.dtype != "real") + shape += tuple(d.size for d in int_inputs.values()) + shape += (dim,) + assert ops.is_numeric_array(prototype) + return ops.randn(prototype, shape, rng_key) @compress_gaussians.register(Gaussian, object, object, tuple) @@ -1056,10 +1132,9 @@ def eager_add_gaussian_gaussian(op, lhs, rhs): return Gaussian(white_vec, prec_sqrt, inputs) -@eager.register(Binary, SubOp, Gaussian, (Funsor, Align, Gaussian)) -@eager.register(Binary, SubOp, (Funsor, Align, Delta), Gaussian) +@eager.register(Binary, SubOp, Gaussian, Gaussian) def eager_sub(op, lhs, rhs): - return lhs + -rhs + return lhs + (-rhs) __all__ = [ diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 5106beb9..9111a4de 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -5,6 +5,7 @@ import typing import jax.numpy as np +import jax.random import numpy as onp from jax import lax from jax.core import Tracer @@ -259,6 +260,12 @@ def _new_zeros(x, shape): return onp.zeros(shape, dtype=np.result_type(x)) +@ops.randn.register(array) +def _randn(prototype, shape, rng_key=None): + assert isinstance(shape, tuple) + return jax.random.normal(rng_key, shape, dtype=prototype.dtype) + + @ops.reciprocal.register(array) def _reciprocal(x): result = np.clip(np.reciprocal(x), a_max=np.finfo(np.result_type(x)).max) diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index 06d66961..c5527377 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -6,10 +6,11 @@ from funsor.cnf import Contraction from funsor.delta import Delta +from funsor.gaussian import Gaussian from funsor.integrate import Integrate from funsor.interpretations import StatefulInterpretation from funsor.tensor import Tensor -from funsor.terms import Approximate, Funsor, Number +from funsor.terms import Approximate, Funsor, Number, Subs, Unary from funsor.util import get_backend from . 
import ops @@ -86,8 +87,11 @@ def _extract_samples_contraction(discrete_density): return result +@extract_samples.register(Subs) @extract_samples.register(Number) @extract_samples.register(Tensor) +@extract_samples.register(Gaussian) +@extract_samples.register(Unary) def _extract_samples_scale(discrete_density): return {} diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 848345c2..b5ae8afe 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -180,7 +180,7 @@ def _astype(x, dtype): @FinitaryOp.make -def cat(parts, axis): +def cat(parts, axis=0): raise NotImplementedError @@ -366,6 +366,12 @@ def new_eye(x, shape): return np.broadcast_to(np.eye(n), shape + (n,)) +@UnaryOp.make +def randn(prototype, shape, rng_key=None): + assert isinstance(shape, tuple) + return np.random.randn(*shape) + + @UnaryOp.make def permute(x, dims): return np.transpose(x, axes=dims) @@ -496,6 +502,7 @@ def unsqueeze(x, dim): "permute", "prod", "qr", + "randn", "sample", "scatter", "scatter_add", diff --git a/funsor/precondition.py b/funsor/precondition.py new file mode 100644 index 00000000..5d080883 --- /dev/null +++ b/funsor/precondition.py @@ -0,0 +1,141 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict + +from . import ops +from .cnf import Contraction, GaussianMixture +from .domains import Reals +from .gaussian import Gaussian +from .interpretations import StatefulInterpretation +from .terms import Approximate, Funsor, Subs, Variable + + +class Precondition(StatefulInterpretation): + """ + Preconditioning interpretation for adjoint computations. + + This interpretation is intended to be used once, followed by a call to + :meth:`combine_subs` as follows:: + + # Lazily build a factor graph. + with reflect: + log_joint = Gaussian(...) + ... + Gaussian(...) + log_Z = log_joint.reduce(ops.logaddexp) + + # Run a backward sampling under the precondition interpretation. + with Precondition() as p: + marginals = adjoint( + ops.logaddexp, ops.add, log_Z, batch_vars=p.sample_vars + ) + combine_subs = p.combine_subs() + + # Extract samples from Delta distributions. + samples = { + k: v(**combine_subs) + for name, delta in marginals.items() + for k, v in funsor.montecarlo.extract_samples(delta).items() + } + + See :func:`~funsor.recipes.forward_filter_backward_precondition` for + complete usage. + + :param str aux_name: Name of the auxiliary variable containing white noise. + """ + + def __init__(self, aux_name="aux"): + super().__init__("precondition") + self.aux_name = aux_name + self.sample_inputs = OrderedDict() + self.sample_vars = set() + + def combine_subs(self): + """ + Method to create a combining substitution after preconditioning is + complete. The returned substitution replaces per-factor auxiliary + variables with slices into a single combined auxiliary variable. + + :returns: A substitution indexing each factor-wise auxiliary variable + into a single global auxiliary variable. 
+        :rtype: dict
+        """
+        total_size = sum(v.num_elements for v in self.sample_inputs.values())
+        aux = Variable(self.aux_name, Reals[total_size])
+        subs = {}
+        start = 0
+        for k, v in self.sample_inputs.items():
+            stop = start + v.num_elements
+            subs[k] = aux[start:stop].reshape(v.shape)
+            start = stop
+        return subs
+
+
+@Precondition.register(Approximate, ops.LogaddexpOp, Funsor, Funsor, frozenset)
+def precondition_approximate_todo(state, op, model, guide, approx_vars):
+    if approx_vars.isdisjoint(guide.input_vars):
+        return
+    raise NotImplementedError("TODO handle:\n" + guide.pretty(100, 0))
+
+
+@Precondition.register(
+    Approximate,
+    ops.LogaddexpOp,
+    Funsor,
+    Contraction[ops.NullOp, ops.AddOp, frozenset, tuple],
+    frozenset,
+)
+def precondition_approximate_contraction(state, op, model, guide, approx_vars):
+    # Eagerly winnow approx_vars.
+    approx_vars = approx_vars.intersection(guide.input_vars)
+    if not approx_vars:
+        return model
+
+    terms = [
+        term for term in guide.terms if not approx_vars.isdisjoint(term.input_vars)
+    ]
+    if len(terms) == 1:
+        guide = terms[0]
+        return Approximate(ops.logaddexp, model, guide, approx_vars)
+    raise NotImplementedError("TODO")
+
+
+@Precondition.register(Approximate, ops.LogaddexpOp, Funsor, GaussianMixture, frozenset)
+def precondition_approximate_gaussian_mixture(state, op, model, guide, approx_vars):
+    tensor, gaussian = guide.terms
+    return precondition_approximate_gaussian(state, op, model, gaussian, approx_vars)
+
+
+@Precondition.register(Approximate, ops.LogaddexpOp, Funsor, Gaussian, frozenset)
+@Precondition.register(
+    Approximate, ops.LogaddexpOp, Funsor, Subs[Gaussian, tuple], frozenset
+)
+def precondition_approximate_gaussian(state, op, model, guide, approx_vars):
+    # Eagerly winnow approx_vars.
+    approx_vars = approx_vars.intersection(guide.input_vars)
+    if not approx_vars:
+        return model
+
+    # Determine how much white noise is needed to generate a sample:
+    # one batch dimension per integer input, plus a flattened event
+    # dimension counting the elements of the approximated real inputs.
+    batch_shape = []
+    event_numel = 0
+    for k, d in guide.inputs.items():
+        if d.dtype == "real":
+            if Variable(k, d) in approx_vars:
+                event_numel += d.num_elements
+        else:
+            batch_shape += (d.size,)
+    shape = tuple(batch_shape) + (event_numel,)
+    name = f"{state.aux_name}_{len(state.sample_inputs)}"
+    state.sample_inputs[name] = Reals[shape]
+    state.sample_vars.add(Variable(name, Reals[shape]))
+
+    # Precondition this factor.
+    sample = guide.sample(approx_vars, OrderedDict([(name, Reals[shape])]))
+    assert sample is not guide, "no progress"
+    result = sample + model - guide
+    return result
+
+
+__all__ = [
+    "Precondition",
+]
diff --git a/funsor/recipes.py b/funsor/recipes.py
index 1e52b2ee..6e119037 100644
--- a/funsor/recipes.py
+++ b/funsor/recipes.py
@@ -82,3 +82,79 @@ def forward_filter_backward_rsample(
     assert set(log_prob.inputs) == set(sample_inputs)
 
     return samples, log_prob
+
+
+def forward_filter_backward_precondition(
+    factors: Dict[str, funsor.Funsor],
+    eliminate: FrozenSet[str],
+    plates: FrozenSet[str],
+    aux_name: str = "aux",
+):
+    """
+    A forward-filter backward-precondition algorithm for use in variational
+    inference or preconditioning in Hamiltonian Monte Carlo. The motivating use
+    case is performing Gaussian tensor variable elimination over structured
+    variational posteriors, and optionally using the learned posterior to
+    determine momentum in HMC.
+
+    :param dict factors: A dictionary mapping sample site name to a Funsor
+        factor created at that sample site.
+    :param frozenset eliminate: A set of names of latent variables to
+        marginalize and plates to aggregate.
+    :param frozenset plates: A set of names of plates to aggregate.
+    :param str aux_name: Name of the auxiliary variable containing white noise.
+    :returns: A pair ``samples:Dict[str, Tensor], log_prob: Tensor`` of samples
+        and log density evaluated at each of those samples. Both outputs depend
+        on a vector named by ``aux_name``, e.g. ``aux: Reals[d]`` where ``d``
+        is the total number of elements in the eliminated variables.
+    :rtype: tuple
+    """
+    assert isinstance(factors, dict)
+    assert all(isinstance(k, str) for k in factors)
+    assert all(isinstance(v, funsor.Funsor) for v in factors.values())
+    assert isinstance(eliminate, frozenset)
+    assert all(isinstance(v, str) for v in eliminate)
+    assert isinstance(plates, frozenset)
+    assert all(isinstance(v, str) for v in plates)
+    assert isinstance(aux_name, str)
+    assert not any(aux_name in f.inputs for f in factors.values())
+
+    # Perform tensor variable elimination.
+    with funsor.interpretations.reflect:
+        log_Z = funsor.sum_product.sum_product(
+            funsor.ops.logaddexp,
+            funsor.ops.add,
+            list(factors.values()),
+            eliminate,
+            plates,
+        )
+        log_Z = funsor.optimizer.apply_optimizer(log_Z)
+    with funsor.precondition.Precondition(aux_name=aux_name) as precondition:
+        log_Z, marginals = funsor.adjoint.forward_backward(
+            funsor.ops.logaddexp,
+            funsor.ops.add,
+            log_Z,
+            batch_vars=precondition.sample_vars,
+        )
+
+    # Extract sample tensors.
+    samples = {}
+    for name, factor in factors.items():
+        if name in eliminate:
+            samples.update(funsor.montecarlo.extract_samples(marginals[factor]))
+    assert frozenset(samples) == eliminate - plates
+
+    # Combine into a single auxiliary variable.
+    subs = precondition.combine_subs()
+    samples = {k: v(**subs) for k, v in samples.items()}
+
+    # Compute log density at each sample, lazily dependent on aux_name.
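+    # Note log_prob = sum(factors)(samples) - log_Z: each factor is evaluated
+    # at the reparameterized samples (summing over plates), and subtracting
+    # the log normalizer log_Z makes this the posterior log density, as a
+    # function of the single remaining input named by aux_name.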
+ log_prob = -log_Z + for f in factors.values(): + term = f(**samples) + plates = eliminate.intersection(term.inputs) + term = term.reduce(funsor.ops.add, plates) + log_prob += term + assert set(log_prob.inputs) == {aux_name} + + return samples, log_prob diff --git a/funsor/tensor.py b/funsor/tensor.py index a2bac8ff..d23dcf55 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -883,6 +883,19 @@ def eager_finitary_stack(op, parts): return Tensor(raw_result, inputs, parts[0].dtype) +@eager.register(Finitary, ops.CatOp, typing.Tuple[Tensor, ...]) +def eager_finitary_cat(op, parts): + dim = op.defaults["axis"] + if dim >= 0: + event_dims = {len(part.output.shape) for part in parts} + assert len(event_dims) == 1, "undefined" + dim = dim - next(iter(event_dims)) + assert dim < 0 + inputs, raw_parts = align_tensors(*parts, expand=True) + raw_result = ops.cat(raw_parts, dim) + return Tensor(raw_result, inputs, parts[0].dtype) + + @eager.register(Finitary, FinitaryOp, typing.Tuple[typing.Union[(Number, Tensor)], ...]) def eager_finitary_generic_tensors(op, args): inputs, raw_args = align_tensors(*args) diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index dfcfa5cc..4f44d6ca 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -321,6 +321,12 @@ def _new_full(x, shape, value): return x.new_full(shape, value) +@ops.randn.register(torch.Tensor) +def _randn(prototype, shape, rng_key=None): + assert isinstance(shape, tuple) + return torch.randn(shape, dtype=prototype.dtype, device=prototype.device) + + @ops.permute.register(torch.Tensor) def _permute(x, dims): return x.permute(dims) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index f79b3dd9..704b0d4c 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -24,6 +24,7 @@ ) from funsor.integrate import Integrate from funsor.interpretations import eager, lazy +from funsor.montecarlo import extract_samples from funsor.tensor import Einsum, Tensor, numeric_array from funsor.terms import Number, Subs, Unary, Variable from funsor.testing import ( @@ -324,7 +325,7 @@ def test_meta(loc, scale): ("shift + g1", Contraction), ("shift - g1", Contraction), ("g1 + g1", (Gaussian, Contraction)), - ("(g1 + g2 + g2) - g2", Contraction), + ("(g1 + g2 + g2) - g2", (Gaussian, Contraction)), ("g1(i=i0)", Gaussian), ("g2(i=i0)", Gaussian), ("g1(i=i0) + g2(i=i0)", Gaussian), @@ -763,7 +764,7 @@ def test_reduce_logsumexp(int_inputs, real_inputs): ], ids=id_from_inputs, ) -def test_reduce_logsumexp_subs(int_inputs): +def test_reduce_logsumexp_partial(int_inputs): int_inputs = OrderedDict(sorted(int_inputs.items())) real_inputs = OrderedDict( [("w", Reals[2]), ("x", Reals[4]), ("y", Reals[2, 3]), ("z", Real)] @@ -777,14 +778,72 @@ def test_reduce_logsumexp_subs(int_inputs): k: Tensor(randn(batch_shape + v.shape), int_inputs) for k, v in real_inputs.items() } + real_vars = frozenset("wxyz") subsets = "w x y z wx wy wz xy xz yz wxy wxz wyz xyz".split() for reduced_vars in map(frozenset, subsets): values = {k: v for k, v in all_values.items() if k not in reduced_vars} - actual = g.reduce(ops.logaddexp, reduced_vars)(**all_values) + + # Check two ways of completely marginalizing. + expected = g.reduce(ops.logaddexp, real_vars) + actual = g.reduce(ops.logaddexp, reduced_vars).reduce( + ops.logaddexp, real_vars - reduced_vars + ) + assert_close(actual, expected, atol=1e-4, rtol=None) + + # Check two ways of substituting. 
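+        # Marginalizing first and then substituting must agree with
+        # substituting first and then marginalizing; substitutions of the
+        # already-reduced variables in all_values are ignored, since they
+        # are no longer inputs of the reduced funsor.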
         expected = g(**values).reduce(ops.logaddexp, reduced_vars)
+        actual = g.reduce(ops.logaddexp, reduced_vars)(**all_values)
         assert_close(actual, expected, atol=1e-4, rtol=None)
 
 
+@pytest.mark.parametrize(
+    "int_inputs",
+    [
+        OrderedDict(),
+        OrderedDict([("i", Bint[2])]),
+        OrderedDict([("i", Bint[2]), ("j", Bint[3])]),
+    ],
+    ids=id_from_inputs,
+)
+def test_sample_partial(int_inputs):
+    int_inputs = OrderedDict(sorted(int_inputs.items()))
+    real_inputs = OrderedDict(
+        [("w", Reals[2]), ("x", Reals[4]), ("y", Reals[2, 3]), ("z", Real)]
+    )
+    inputs = int_inputs.copy()
+    inputs.update(real_inputs)
+    flat = ops.cat(
+        [Variable(k, d).reshape((d.num_elements,)) for k, d in real_inputs.items()]
+    )
+
+    def compute_moments(samples):
+        flat_samples = flat(**extract_samples(samples))
+        assert set(flat_samples.inputs) == {"particle"} | set(int_inputs)
+        mean = flat_samples.reduce(ops.mean, "particle")
+        diff = flat_samples - mean
+        cov = (diff[:, None] * diff[None, :]).reduce(ops.mean, "particle")
+        return mean, cov
+
+    sample_inputs = OrderedDict(particle=Bint[50000])
+    rng_keys = [None] * 3
+    if get_backend() == "jax":
+        import jax.random
+
+        rng_keys = jax.random.split(np.array([0, 0], dtype=np.uint32), 3)
+
+    g = random_gaussian(inputs)
+    all_vars = frozenset("wxyz")
+    samples = g.sample(all_vars, sample_inputs, rng_keys[0])
+    expected_mean, expected_cov = compute_moments(samples)
+
+    subsets = "w x y z wx wy wz xy xz yz wxy wxz wyz xyz".split()
+    for sampled_vars in map(frozenset, subsets):
+        g2 = g.sample(sampled_vars, sample_inputs, rng_keys[1])
+        samples = g2.sample(all_vars, sample_inputs, rng_keys[2])
+        actual_mean, actual_cov = compute_moments(samples)
+        assert_close(actual_mean, expected_mean, atol=1e-1, rtol=1e-1)
+        assert_close(actual_cov, expected_cov, atol=1e-1, rtol=1e-1)
+
+
 @pytest.mark.parametrize("int_inputs", [{}, {"i": Bint[2]}], ids=id_from_inputs)
 @pytest.mark.parametrize(
     "real_inputs",
diff --git a/test/test_recipes.py b/test/test_recipes.py
index 8eb249a7..5f764fbd 100644
--- a/test/test_recipes.py
+++ b/test/test_recipes.py
@@ -9,10 +9,14 @@
 
 import funsor.ops as ops
 from funsor.domains import Bint, Real, Reals
+from funsor.interpretations import memoize
 from funsor.montecarlo import extract_samples
-from funsor.recipes import forward_filter_backward_rsample
+from funsor.recipes import (
+    forward_filter_backward_precondition,
+    forward_filter_backward_rsample,
+)
 from funsor.terms import Lambda, Variable
-from funsor.testing import assert_close, random_gaussian
+from funsor.testing import Tensor, assert_close, randn, random_gaussian
 from funsor.util import get_backend
 
 
@@ -92,7 +96,22 @@ def check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob):
     assert_close(actual_moments, expected_moments, atol=0.02, rtol=None)
 
 
-def test_ffbr_1():
+def substitute_aux(samples, log_prob, num_samples):
+    assert all("aux" in v.inputs for v in samples.values())
+    assert set(log_prob.inputs) == {"aux"}
+
+    # Substitute noise for the aux value, as would happen each SVI step.
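+    # Both samples and log_prob are deterministic functions of the single
+    # "aux" input, so substituting a batch of standard normal noise yields
+    # reparameterized posterior samples through which gradients can flow.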
+ aux_numel = log_prob.inputs["aux"].num_elements + noise = Tensor(randn(num_samples, aux_numel))["particle"] + with memoize(): + samples = {k: v(aux=noise) for k, v in samples.items()} + log_prob = log_prob(aux=noise) + + return samples, log_prob + + +@pytest.mark.parametrize("backward", ["sample", "precondition"]) +def test_ffb_1(backward): """ def model(data): a = pyro.sample("a", dist.Normal(0, 1)) @@ -106,20 +125,27 @@ def model(data): } eliminate = frozenset(["a"]) plates = frozenset() - sample_inputs = OrderedDict(particle=Bint[num_samples]) - rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) - actual_samples, actual_log_prob = forward_filter_backward_rsample( - factors, eliminate, plates, sample_inputs, rng_key - ) + if backward == "sample": + sample_inputs = OrderedDict(particle=Bint[num_samples]) + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + elif backward == "precondition": + samples, log_prob = forward_filter_backward_precondition( + factors, eliminate, plates + ) + actual_samples, actual_log_prob = substitute_aux(samples, log_prob, num_samples) + assert set(actual_samples) == {"a"} assert actual_samples["a"].output == Real assert set(actual_samples["a"].inputs) == {"particle"} - check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) -def test_ffbr_2(): +@pytest.mark.parametrize("backward", ["sample", "precondition"]) +def test_ffb_2(backward): """ def model(data): a = pyro.sample("a", dist.Normal(0, 1)) @@ -135,22 +161,29 @@ def model(data): } eliminate = frozenset(["a", "b"]) plates = frozenset() - sample_inputs = {"particle": Bint[num_samples]} - rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) - actual_samples, actual_log_prob = forward_filter_backward_rsample( - factors, eliminate, plates, sample_inputs, rng_key - ) + if backward == "sample": + sample_inputs = {"particle": Bint[num_samples]} + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + elif backward == "precondition": + samples, log_prob = forward_filter_backward_precondition( + factors, eliminate, plates + ) + actual_samples, actual_log_prob = substitute_aux(samples, log_prob, num_samples) + assert set(actual_samples) == {"a", "b"} assert actual_samples["a"].output == Real assert actual_samples["b"].output == Real assert set(actual_samples["a"].inputs) == {"particle"} assert set(actual_samples["b"].inputs) == {"particle"} - check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) -def test_ffbr_3(): +@pytest.mark.parametrize("backward", ["sample", "precondition"]) +def test_ffb_3(backward): """ def model(data): a = pyro.sample("a", dist.Normal(0, 1)) @@ -167,22 +200,29 @@ def model(data): } eliminate = frozenset(["a", "b", "i"]) plates = frozenset(["i"]) - sample_inputs = {"particle": Bint[num_samples]} - rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) - actual_samples, actual_log_prob = forward_filter_backward_rsample( - factors, eliminate, plates, sample_inputs, rng_key - ) + if backward == "sample": + sample_inputs = {"particle": Bint[num_samples]} + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = 
forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + elif backward == "precondition": + samples, log_prob = forward_filter_backward_precondition( + factors, eliminate, plates + ) + actual_samples, actual_log_prob = substitute_aux(samples, log_prob, num_samples) + assert set(actual_samples) == {"a", "b"} assert actual_samples["a"].output == Real assert actual_samples["b"].output == Real assert set(actual_samples["a"].inputs) == {"particle"} assert set(actual_samples["b"].inputs) == {"particle", "i"} - check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) -def test_ffbr_4(): +@pytest.mark.parametrize("backward", ["sample", "precondition"]) +def test_ffb_4(backward): """ def model(data): a = pyro.sample("a", dist.Normal(0, 1)) @@ -206,12 +246,19 @@ def model(data): } eliminate = frozenset(["a", "b", "c", "d", "i", "j"]) plates = frozenset(["i", "j"]) - sample_inputs = {"particle": Bint[num_samples]} - rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) - actual_samples, actual_log_prob = forward_filter_backward_rsample( - factors, eliminate, plates, sample_inputs, rng_key - ) + if backward == "sample": + sample_inputs = {"particle": Bint[num_samples]} + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + elif backward == "precondition": + samples, log_prob = forward_filter_backward_precondition( + factors, eliminate, plates + ) + actual_samples, actual_log_prob = substitute_aux(samples, log_prob, num_samples) + assert set(actual_samples) == {"a", "b", "c", "d"} assert actual_samples["a"].output == Real assert actual_samples["b"].output == Real @@ -221,11 +268,11 @@ def model(data): assert set(actual_samples["b"].inputs) == {"particle"} assert set(actual_samples["c"].inputs) == {"particle", "i"} assert set(actual_samples["d"].inputs) == {"particle", "i"} - check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) -def test_ffbr_5(): +@pytest.mark.parametrize("backward", ["sample", "precondition"]) +def test_ffb_5(backward): """ def model(data): a = pyro.sample("a", dist.MultivariateNormal(zeros(2), eye(2))) @@ -245,12 +292,19 @@ def model(data): } eliminate = frozenset(["a", "b", "c", "d"]) plates = frozenset() - sample_inputs = {"particle": Bint[num_samples]} - rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) - actual_samples, actual_log_prob = forward_filter_backward_rsample( - factors, eliminate, plates, sample_inputs, rng_key - ) + if backward == "sample": + sample_inputs = {"particle": Bint[num_samples]} + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + elif backward == "precondition": + samples, log_prob = forward_filter_backward_precondition( + factors, eliminate, plates + ) + actual_samples, actual_log_prob = substitute_aux(samples, log_prob, num_samples) + assert set(actual_samples) == {"a", "b", "c", "d"} assert actual_samples["a"].output == Reals[2] assert actual_samples["b"].output == Reals[2] @@ -260,12 +314,12 @@ def model(data): assert set(actual_samples["b"].inputs) == {"particle"} assert set(actual_samples["c"].inputs) == {"particle"} assert set(actual_samples["d"].inputs) == {"particle"} - check_ffbr(factors, eliminate, plates, 
actual_samples, actual_log_prob) @pytest.mark.xfail(reason="TODO handle intractable case") -def test_ffbr_intractable_1(): +@pytest.mark.parametrize("backward", ["sample", "precondition"]) +def test_ffb_intractable_1(backward): """ def model(data): i_plate = pyro.plate("i", 2, dim=-2) @@ -288,23 +342,30 @@ def model(data): } eliminate = frozenset(["a", "b", "i", "j"]) plates = frozenset(["i", "j"]) - sample_inputs = {"particle": Bint[num_samples]} - rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) - actual_samples, actual_log_prob = forward_filter_backward_rsample( - factors, eliminate, plates, sample_inputs, rng_key - ) + if backward == "sample": + sample_inputs = {"particle": Bint[num_samples]} + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + elif backward == "precondition": + samples, log_prob = forward_filter_backward_precondition( + factors, eliminate, plates + ) + actual_samples, actual_log_prob = substitute_aux(samples, log_prob, num_samples) + assert set(actual_samples) == {"a", "b"} assert actual_samples["a"].output == Real assert actual_samples["b"].output == Real assert set(actual_samples["a"].inputs) == {"particle", "i"} assert set(actual_samples["b"].inputs) == {"particle", "j"} - check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) @pytest.mark.xfail(reason="TODO handle colliders via Lambda") -def test_ffbr_intractable_2(): +@pytest.mark.parametrize("backward", ["sample", "precondition"]) +def test_ffb_intractable_2(backward): """ def model(data): with pyro.plate("i", 2): @@ -319,13 +380,19 @@ def model(data): } eliminate = frozenset(["a", "i"]) plates = frozenset(["i"]) - sample_inputs = {"particle": Bint[num_samples]} - rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) - actual_samples, actual_log_prob = forward_filter_backward_rsample( - factors, eliminate, plates, sample_inputs, rng_key - ) + if backward == "sample": + sample_inputs = {"particle": Bint[num_samples]} + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + elif backward == "precondition": + samples, log_prob = forward_filter_backward_precondition( + factors, eliminate, plates + ) + actual_samples, actual_log_prob = substitute_aux(samples, log_prob, num_samples) + assert set(actual_samples) == {"a"} assert set(actual_samples["a"].inputs) == {"particle", "i"} - check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) diff --git a/test/test_terms.py b/test/test_terms.py index e8a20dae..4e38be54 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -584,6 +584,17 @@ def test_cat_simple(): assert xy(i=i) is Number(i) +@pytest.mark.parametrize("right_shape", [(), (4,), (3, 2)], ids=str) +@pytest.mark.parametrize("left_shape", [(), (4,), (3, 2)], ids=str) +def test_cat_variable(left_shape, right_shape): + x = Variable("x", Reals[left_shape + (1,) + right_shape]) + y = Variable("y", Reals[left_shape + (2,) + right_shape]) + z = Variable("z", Reals[left_shape + (3,) + right_shape]) + + actual = ops.cat([x, y, z], -1 - len(right_shape)) + assert actual.output == Reals[left_shape + (6,) + right_shape] + + def test_align_simple(): x = Variable("x", Real) y = Variable("y", Real)