refactor Gaussian to use Tensor api instead of torch.Tensor (#296)
fehiepsi authored and fritzo committed Jan 12, 2020
1 parent e7ae22d commit 0f3ec7d
Showing 10 changed files with 362 additions and 275 deletions.
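In outline, the refactor replaces torch-specific helpers (e.g. funsor.gaussian.cholesky_inverse) with generic operations in funsor.ops that dispatch on the type of the underlying data, so funsor/gaussian.py no longer manipulates torch.Tensor directly. A minimal sketch of the idea, assuming funsor.numpy registers numpy implementations for these ops (as it does for _cholesky_inverse further down this diff):

import numpy as np

import funsor.ops as ops
import funsor.numpy  # noqa: F401  (side effect: registers np.ndarray implementations)

cov = np.array([[2.0, 0.5], [0.5, 1.0]])
L = ops.cholesky(cov)            # dispatches to the numpy backend
inv = ops.cholesky_inverse(L)    # the same call works on torch.Tensor data
print(np.allclose(inv, np.linalg.inv(cov)))  # True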
4 changes: 2 additions & 2 deletions funsor/distributions.py
@@ -12,7 +12,7 @@
import funsor.ops as ops
from funsor.affine import is_affine
from funsor.domains import bint, reals
-from funsor.gaussian import Gaussian, cholesky_inverse
+from funsor.gaussian import Gaussian
from funsor.interpreter import gensym, interpretation
from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_funsor
from funsor.torch import Tensor, align_tensors, ignore_jit_warnings, materialize, torch_stack
@@ -511,7 +511,7 @@ def eager_mvn(loc, scale_tril, value):
        return None  # lazy

    info_vec = scale_tril.data.new_zeros(scale_tril.data.shape[:-1])
-    precision = cholesky_inverse(scale_tril.data)
+    precision = ops.cholesky_inverse(scale_tril.data)
    scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2), scale_tril.inputs)
    log_prob = -0.5 * scale_diag.shape[0] * math.log(2 * math.pi) - scale_diag.log().sum()
    inputs = scale_tril.inputs.copy()
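For reference, the normalizer computed above follows from the Cholesky factorization $\Sigma = L L^\top$: since $\det\Sigma = \prod_i L_{ii}^2$, the multivariate normal log-normalizer is

$$\log Z = -\tfrac{k}{2}\log(2\pi) - \tfrac{1}{2}\log\det\Sigma = -\tfrac{k}{2}\log(2\pi) - \sum_i \log L_{ii},$$

which is exactly what the `log_prob` line computes, with $k$ = `scale_diag.shape[0]` and $L_{ii}$ the diagonal of `scale_tril`.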
204 changes: 125 additions & 79 deletions funsor/gaussian.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions funsor/integrate.py
@@ -7,7 +7,7 @@
import funsor.ops as ops
from funsor.cnf import Contraction, GaussianMixture
from funsor.delta import Delta
-from funsor.gaussian import Gaussian, _mv, _trace_mm, _vv, align_gaussian, cholesky_inverse
+from funsor.gaussian import Gaussian, align_gaussian, _mv, _trace_mm, _vv
from funsor.terms import (
    Funsor,
    FunsorMeta,
@@ -164,7 +164,7 @@ def eager_integrate(log_measure, integrand, reduced_vars):
# See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380.
# http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
norm = lhs.log_normalizer.data.exp()
lhs_cov = cholesky_inverse(lhs._precision_chol)
lhs_cov = ops.cholesky_inverse(lhs._precision_chol)
lhs_loc = lhs.info_vec.unsqueeze(-1).cholesky_solve(lhs._precision_chol).squeeze(-1)
vmv_term = _vv(lhs_loc, rhs_info_vec - 0.5 * _mv(rhs_precision, lhs_loc))
data = norm * (vmv_term - 0.5 * _trace_mm(rhs_precision, lhs_cov))
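The identity cited in the comment above (eq. 380 of The Matrix Cookbook) is the Gaussian expectation of a quadratic form: for $x \sim \mathcal{N}(m, \Sigma)$, $\mathrm{E}[x^\top A x] = \mathrm{Tr}(A\Sigma) + m^\top A m$. Applied to the integrand's log-density terms $b^\top x - \tfrac{1}{2} x^\top P x$ with $b$ = rhs_info_vec and $P$ = rhs_precision, it gives

$$\mathrm{E}\big[b^\top x - \tfrac{1}{2} x^\top P x\big] = m^\top b - \tfrac{1}{2} m^\top P m - \tfrac{1}{2}\mathrm{Tr}(P\Sigma),$$

which is `vmv_term - 0.5 * _trace_mm(rhs_precision, lhs_cov)` above, scaled by the total mass `norm`; here $m$ = lhs_loc and $\Sigma$ = lhs_cov.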
6 changes: 3 additions & 3 deletions funsor/joint.py
@@ -14,7 +14,7 @@
from funsor.cnf import Contraction, GaussianMixture
from funsor.delta import Delta
from funsor.domains import bint
-from funsor.gaussian import Gaussian, align_gaussian, cholesky, cholesky_inverse
+from funsor.gaussian import Gaussian, align_gaussian
from funsor.ops import AssociativeOp
from funsor.terms import Funsor, Independent, Number, Reduce, Unary, eager, moment_matching, normalize
from funsor.torch import Tensor, align_tensor
@@ -106,7 +106,7 @@ def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gauss
    old_loc = Tensor(gaussian.info_vec.unsqueeze(-1).cholesky_solve(gaussian._precision_chol).squeeze(-1),
                     int_inputs)
    new_loc = (probs * old_loc).reduce(ops.add, approx_vars)
-    old_cov = Tensor(cholesky_inverse(gaussian._precision_chol), int_inputs)
+    old_cov = Tensor(ops.cholesky_inverse(gaussian._precision_chol), int_inputs)
    diff = old_loc - new_loc
    outers = Tensor(diff.data.unsqueeze(-1) * diff.data.unsqueeze(-2), diff.inputs)
    new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) +
@@ -117,7 +117,7 @@ def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gauss
    mask = (total.data == 0).to(total.data.dtype).unsqueeze(-1).unsqueeze(-1)
    new_cov.data += mask * torch.eye(new_cov.data.size(-1))

-    new_precision = Tensor(cholesky_inverse(cholesky(new_cov.data)), new_cov.inputs)
+    new_precision = Tensor(ops.cholesky_inverse(ops.cholesky(new_cov.data)), new_cov.inputs)
    new_info_vec = new_precision.data.matmul(new_loc.data.unsqueeze(-1)).squeeze(-1)
    new_inputs = new_loc.inputs.copy()
    new_inputs.update((k, d) for k, d in gaussian.inputs.items() if d.dtype == 'real')
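The update above is the standard moment-matching collapse of a mixture of Gaussians into a single Gaussian: given weights $p_i$, component means $m_i$, and covariances $C_i$,

$$m = \sum_i p_i m_i, \qquad C = \sum_i p_i \left( C_i + (m_i - m)(m_i - m)^\top \right),$$

after which the code converts back to information form via the new generic ops: precision $= C^{-1}$ (computed as `ops.cholesky_inverse(ops.cholesky(...))`) and info_vec $=$ precision $\cdot\, m$.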
152 changes: 84 additions & 68 deletions funsor/numpy.py
@@ -11,60 +11,6 @@
from funsor.terms import Binary, Funsor, FunsorMeta, Number, eager, substitute, to_data, to_funsor


-def align_array(new_inputs, x):
-    r"""
-    Permute and expand an array to match desired ``new_inputs``.
-
-    :param OrderedDict new_inputs: A target set of inputs.
-    :param funsor.terms.Funsor x: An :class:`Array` or
-        :class:`~funsor.terms.Number`.
-    :return: a number or :class:`numpy.ndarray` that can be broadcast to
-        other arrays with inputs ``new_inputs``.
-    :rtype: int or float or numpy.ndarray
-    """
-    assert isinstance(new_inputs, OrderedDict)
-    assert isinstance(x, (Number, Array))
-    assert all(isinstance(d.dtype, int) for d in x.inputs.values())
-
-    data = x.data
-    if isinstance(x, Number):
-        return data
-
-    old_inputs = x.inputs
-    if old_inputs == new_inputs:
-        return data
-
-    # Permute squashed input dims.
-    x_keys = tuple(old_inputs)
-    data = np.transpose(data, (tuple(x_keys.index(k) for k in new_inputs if k in old_inputs) +
-                               tuple(range(len(old_inputs), data.ndim))))
-
-    # Unsquash multivariate input dims by filling in ones.
-    data = np.reshape(data, tuple(old_inputs[k].dtype if k in old_inputs else 1 for k in new_inputs) +
-                      x.output.shape)
-    return data
-
-
-def align_arrays(*args):
-    r"""
-    Permute multiple arrays before applying a broadcasted op.
-
-    This is mainly useful for implementing eager funsor operations.
-
-    :param funsor.terms.Funsor \*args: Multiple :class:`Array` s and
-        :class:`~funsor.terms.Number` s.
-    :return: a pair ``(inputs, arrays)`` where arrays are all
-        :class:`numpy.ndarray` s that can be broadcast together into a single
-        array with the given ``inputs``.
-    :rtype: tuple
-    """
-    inputs = OrderedDict()
-    for x in args:
-        inputs.update(x.inputs)
-    arrays = [align_array(inputs, x) for x in args]
-    return inputs, arrays


class ArrayMeta(FunsorMeta):
"""
Wrapper to fill in default args and convert between OrderedDict and tuple.
@@ -207,6 +153,11 @@ def eager_subs(self, subs):
        return Array(data, inputs, self.dtype)


+@ops.TensorOp.register(np.ndarray, (type(None), tuple, OrderedDict), str)
+def _Tensor(x, inputs, dtype):
+    return Array(x, inputs, dtype)
+
+
@dispatch(np.ndarray)
def to_funsor(x):
    return Array(x)
@@ -221,6 +172,69 @@ def to_funsor(x, output):
    return result


+def align_array(new_inputs, x, expand=False):
+    r"""
+    Permute and expand an array to match desired ``new_inputs``.
+
+    :param OrderedDict new_inputs: A target set of inputs.
+    :param funsor.terms.Funsor x: An :class:`Array` or
+        :class:`~funsor.terms.Number`.
+    :param bool expand: If False (default), set the result size to 1 for any
+        input of ``new_inputs`` that is not an input of ``x``; if True,
+        expand to the full ``new_inputs`` size.
+    :return: a number or :class:`numpy.ndarray` that can be broadcast to
+        other arrays with inputs ``new_inputs``.
+    :rtype: int or float or numpy.ndarray
+    """
+    assert isinstance(new_inputs, OrderedDict)
+    assert isinstance(x, (Number, Array))
+    assert all(isinstance(d.dtype, int) for d in x.inputs.values())
+
+    data = x.data
+    if isinstance(x, Number):
+        return data
+
+    old_inputs = x.inputs
+    if old_inputs == new_inputs:
+        return data
+
+    # Permute squashed input dims.
+    x_keys = tuple(old_inputs)
+    data = np.transpose(data, (tuple(x_keys.index(k) for k in new_inputs if k in old_inputs) +
+                               tuple(range(len(old_inputs), data.ndim))))
+
+    # Unsquash multivariate input dims by filling in ones.
+    data = np.reshape(data, tuple(old_inputs[k].dtype if k in old_inputs else 1 for k in new_inputs) +
+                      x.output.shape)
+
+    # Optionally expand new dims.
+    if expand:
+        data = np.broadcast_to(data, tuple(d.dtype for d in new_inputs.values()) + x.output.shape)
+    return data
+
+
+def align_arrays(*args, **kwargs):
+    r"""
+    Permute multiple arrays before applying a broadcasted op.
+
+    This is mainly useful for implementing eager funsor operations.
+
+    :param funsor.terms.Funsor \*args: Multiple :class:`Array` s and
+        :class:`~funsor.terms.Number` s.
+    :param bool expand: Whether to expand input arrays. Defaults to False.
+    :return: a pair ``(inputs, arrays)`` where arrays are all
+        :class:`numpy.ndarray` s that can be broadcast together into a single
+        array with the given ``inputs``.
+    :rtype: tuple
+    """
+    expand = kwargs.pop('expand', False)
+    assert not kwargs
+    inputs = OrderedDict()
+    for x in args:
+        inputs.update(x.inputs)
+    arrays = [align_array(inputs, x, expand=expand) for x in args]
+    return inputs, arrays


@to_data.register(Array)
def _to_data_array(x):
    if x.inputs:
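A small usage sketch of the relocated helpers (the inputs here are hypothetical; `bint` supplies the integer domains funsor uses for batch dims):

from collections import OrderedDict

import numpy as np

from funsor.domains import bint
from funsor.numpy import Array, align_arrays

x = Array(np.zeros(2), OrderedDict(i=bint(2)))
y = Array(np.ones(3), OrderedDict(j=bint(3)))

# With expand=False (default), inputs missing from an array get size-1 dims;
# with expand=True they are broadcast out to full size.
inputs, (xd, yd) = align_arrays(x, y)
print(xd.shape, yd.shape)  # (2, 1) (1, 3)
inputs, (xd, yd) = align_arrays(x, y, expand=True)
print(xd.shape, yd.shape)  # (2, 3) (2, 3)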
@@ -322,6 +336,10 @@ def _sigmoid(x):
ops.log1p.register(np.ndarray)(np.log1p)
ops.min.register(np.ndarray, np.ndarray)(np.minimum)
ops.max.register(np.ndarray, np.ndarray)(np.maximum)
+ops.unsqueeze.register(np.ndarray, int)(np.expand_dims)
+ops.expand.register(np.ndarray, tuple)(np.broadcast_to)
+ops.permute.register(np.ndarray, tuple)(np.transpose)
+ops.transpose.register(np.ndarray, int, int)(np.swapaxes)


# TODO: replace (int, float) by object
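Each generic Op doubles as a multimethod, so `op.register(types...)(fn)` is enough to attach a backend implementation, and the one-liners above map the new generic ops straight onto numpy functions. A hedged usage sketch:

import numpy as np

import funsor.ops as ops
import funsor.numpy  # noqa: F401  (side effect: registers the np.ndarray implementations)

a = np.arange(6).reshape(2, 3)
print(ops.unsqueeze(a, 0).shape)     # (1, 2, 3)  via np.expand_dims
print(ops.transpose(a, 0, 1).shape)  # (3, 2)     via np.swapaxes
print(ops.permute(a, (1, 0)).shape)  # (3, 2)     via np.transpose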
@@ -393,7 +411,15 @@ def _cholesky_inverse(x):
def _triangular_solve(x, y, upper, transpose):
    from scipy.linalg import solve_triangular

-    return solve_triangular(x, y, trans=int(transpose), lower=not upper)
+    # TODO: remove this logic when using JAX
+    # Work around scipy's solve_triangular, which does not support batched input.
+    batch_shape = np.broadcast(x[..., 0, 0], y[..., 0, 0]).shape
+    xs = np.broadcast_to(x, batch_shape + x.shape[-2:]).reshape((-1,) + x.shape[-2:])
+    ys = np.broadcast_to(y, batch_shape + y.shape[-2:]).reshape((-1,) + y.shape[-2:])
+    ans = [solve_triangular(y, x, trans=int(transpose), lower=not upper)
+           for (x, y) in zip(xs, ys)]
+    ans = np.stack(ans)
+    return ans.reshape(batch_shape + ans.shape[-2:])


@ops.diagonal.register(np.ndarray, int, int)
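The workaround above is the usual flatten-batch pattern for wrapping an unbatched kernel: broadcast both operands to a shared batch shape, collapse the batch dims into one, apply the scalar routine slice by slice, and reshape back. A self-contained sketch of the same pattern (names are illustrative, not funsor API):

import numpy as np
from scipy.linalg import solve_triangular

def batched_solve_triangular(a, b, lower=True):
    # Broadcast batch dims, flatten them, solve per slice, restore the batch shape.
    batch_shape = np.broadcast(a[..., 0, 0], b[..., 0, 0]).shape
    a_flat = np.broadcast_to(a, batch_shape + a.shape[-2:]).reshape((-1,) + a.shape[-2:])
    b_flat = np.broadcast_to(b, batch_shape + b.shape[-2:]).reshape((-1,) + b.shape[-2:])
    out = np.stack([solve_triangular(ai, bi, lower=lower) for ai, bi in zip(a_flat, b_flat)])
    return out.reshape(batch_shape + out.shape[-2:])

tril = np.tril(np.random.rand(4, 3, 3)) + 3 * np.eye(3)  # batch of 4 triangular matrices
rhs = np.random.rand(3, 2)                               # broadcasts against the batch
print(batched_solve_triangular(tril, rhs).shape)         # (4, 3, 2)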
@@ -416,16 +442,6 @@ def _new_eye(x, shape):
    return np.broadcast_to(np.eye(shape[-1]), shape + (-1,))


-@ops.unsqueeze.register(np.ndarray, int)
-def _unsqueeze(x, dim):
-    return np.expand_dims(x, dim)
-
-
-@ops.expand.register(np.ndarray, tuple)
-def _expand(x, shape):
-    return np.broadcast_to(x, shape)
-
-
-@ops.transpose.register(np.ndarray, int, int)
-def _transpose(x, dim0, dim1):
-    return np.swapaxes(x, dim0, dim1)
+@ops.new_arange.register(np.ndarray, int, int, int)
+def _new_arange(x, start, stop, step):
+    return np.arange(start, stop, step)
19 changes: 19 additions & 0 deletions funsor/ops.py
@@ -346,6 +346,11 @@ def new_eye(x, shape):
    raise NotImplementedError


+@Op
+def new_arange(x, start, stop, step):
+    raise NotImplementedError
+
+
@Op
def unsqueeze(x, dim):
    raise NotImplementedError
@@ -366,6 +371,20 @@ def transpose(x, dim0, dim1):
    raise NotImplementedError


+@Op
+def permute(x, dims):
+    raise NotImplementedError
+
+
+@Op
+def TensorOp(x, inputs, dtype):
+    raise NotImplementedError
+
+
+def Tensor(x, inputs=None, dtype="real"):
+    return TensorOp(x, inputs, dtype)
+
+
__all__ = [
    'AddOp',
    'AssociativeOp',
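The TensorOp stub plus the per-backend registrations (e.g. the np.ndarray one added to funsor/numpy.py above) make ops.Tensor a backend-agnostic factory. A hedged sketch, assuming ArrayMeta fills in default inputs and dtype:

from collections import OrderedDict

import numpy as np

import funsor.ops as ops
import funsor.numpy  # noqa: F401  (side effect: registers the np.ndarray TensorOp implementation)
from funsor.domains import bint

# type(x) picks the backend funsor: an np.ndarray becomes a funsor.numpy.Array,
# while a torch.Tensor would become a funsor.torch.Tensor under the analogous registration.
a = ops.Tensor(np.zeros((2,)), OrderedDict(i=bint(2)))
print(type(a).__name__)  # 'Array'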
5 changes: 3 additions & 2 deletions funsor/pyro/convert.py
@@ -29,10 +29,11 @@
from funsor.delta import Delta
from funsor.distributions import BernoulliLogits, MultivariateNormal, Normal
from funsor.domains import bint, reals
-from funsor.gaussian import Gaussian, align_tensors, cholesky
+from funsor.gaussian import Gaussian
from funsor.interpreter import gensym
+from funsor.ops import cholesky
from funsor.terms import Funsor, Independent, Variable, eager
-from funsor.torch import Tensor
+from funsor.torch import Tensor, align_tensors

# Conversion functions use fixed names for Pyro batch dims, but
# accept an event_inputs tuple for custom event dim names.