refactor Gaussian to use Tensor api instead of torch.Tensor #296

Merged
29 commits merged on Jan 12, 2020
Commits (29)
f93ab60
init: make Gaussian use Tensor, instead of torch.Tensor
fehiepsi Jan 4, 2020
ac64a6a
move cholesky to ops
fehiepsi Jan 4, 2020
4fbb404
add ops.cat
fehiepsi Jan 5, 2020
8021fbc
move all linear algebra to ops
fehiepsi Jan 5, 2020
59dd408
commit changes
fehiepsi Jan 6, 2020
4d9a6ca
clean
fehiepsi Jan 6, 2020
6dd7e3a
revert unnecessary ops
fehiepsi Jan 6, 2020
6361828
add new ops
fehiepsi Jan 6, 2020
e2e974d
increase threshold for eager_subs_affine
fehiepsi Jan 6, 2020
0f00245
test numpy ops
fehiepsi Jan 6, 2020
a7156ae
add failing unary test
fehiepsi Jan 6, 2020
a74c942
address some comments
fehiepsi Jan 6, 2020
f4c3fd7
fix typo
fehiepsi Jan 6, 2020
3edc529
revise finfo in torch too
fehiepsi Jan 7, 2020
ea279c2
mark failing tests as xfail
fehiepsi Jan 9, 2020
ea807bf
merge newop
fehiepsi Jan 9, 2020
3152972
lint
fehiepsi Jan 9, 2020
400c919
use object instead of int, float for torch ops
fehiepsi Jan 9, 2020
49f62bb
Merge branch 'newop' into gaussian-tensor
fehiepsi Jan 9, 2020
7408380
resolve merge conflict
fehiepsi Jan 9, 2020
c956cc7
add numpy backend to smoke test
fehiepsi Jan 9, 2020
2aaaf19
add ops.align_tensor and ops.materialize
fehiepsi Jan 10, 2020
88624a3
add scipy to setup file
fehiepsi Jan 10, 2020
9c777b1
fix import error
fehiepsi Jan 10, 2020
969767b
remove ops align tensors and materialize
fehiepsi Jan 10, 2020
74fe5eb
remove ops.materialize
fehiepsi Jan 11, 2020
9f1dfdb
Merge remote-tracking branch 'upstream' into gaussian-tensor
fehiepsi Jan 11, 2020
b82bba4
fix bugs
fehiepsi Jan 11, 2020
9a35746
not sure if permute require tuple or *args
fehiepsi Jan 11, 2020
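
Taken together, these commits replace Gaussian's direct torch.Tensor method calls with generic operations in funsor.ops, each registered separately per backend (the funsor/numpy.py diff below adds registrations such as ops.unsqueeze.register(np.ndarray, int)(np.expand_dims)). The following is a minimal sketch of that single-dispatch registration pattern; the MiniOp class is illustrative only and is not funsor's actual Op implementation.

# Illustrative sketch (not funsor's Op class): one generic op name routes to a
# backend implementation chosen by argument type, mirroring the registrations
# added for np.ndarray in this PR.
import numpy as np


class MiniOp:
    def __init__(self, name):
        self.name = name
        self.registry = {}

    def register(self, *types):
        def decorator(fn):
            self.registry[types] = fn
            return fn
        return decorator

    def __call__(self, *args):
        for types, fn in self.registry.items():
            if len(types) == len(args) and all(isinstance(a, t) for a, t in zip(args, types)):
                return fn(*args)
        raise NotImplementedError("no implementation of {} for {}".format(
            self.name, tuple(type(a).__name__ for a in args)))


cholesky = MiniOp("cholesky")
cholesky.register(np.ndarray)(np.linalg.cholesky)   # numpy backend; a torch backend would register torch.cholesky

print(cholesky(np.array([[4.0, 2.0], [2.0, 3.0]])))  # dispatches to np.linalg.cholesky

With this pattern, Gaussian only ever calls the generic op, and supporting a new backend amounts to registering implementations for its array type.
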
4 changes: 2 additions & 2 deletions funsor/distributions.py
@@ -12,7 +12,7 @@
import funsor.ops as ops
from funsor.affine import is_affine
from funsor.domains import bint, reals
from funsor.gaussian import Gaussian, cholesky_inverse
from funsor.gaussian import Gaussian
from funsor.interpreter import gensym, interpretation
from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_funsor
from funsor.torch import Tensor, align_tensors, ignore_jit_warnings, materialize, torch_stack
@@ -511,7 +511,7 @@ def eager_mvn(loc, scale_tril, value):
return None # lazy

info_vec = scale_tril.data.new_zeros(scale_tril.data.shape[:-1])
precision = cholesky_inverse(scale_tril.data)
precision = ops.cholesky_inverse(scale_tril.data)
scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2), scale_tril.inputs)
log_prob = -0.5 * scale_diag.shape[0] * math.log(2 * math.pi) - scale_diag.log().sum()
inputs = scale_tril.inputs.copy()
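
For the eager_mvn change above: ops.cholesky_inverse is expected to keep torch.cholesky_inverse's semantics, i.e. cholesky_inverse(L) returns the inverse of L @ L.T, so applying it to scale_tril yields the MVN precision matrix. A quick numpy check of that identity (a sketch of the intended semantics, not funsor's implementation):

import numpy as np

scale_tril = np.array([[2.0, 0.0],
                       [1.0, 3.0]])      # lower-triangular Cholesky factor L
covariance = scale_tril @ scale_tril.T   # Sigma = L @ L.T
precision = np.linalg.inv(covariance)    # what cholesky_inverse(scale_tril) should produce

assert np.allclose(precision @ covariance, np.eye(2))
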
204 changes: 125 additions & 79 deletions funsor/gaussian.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions funsor/integrate.py
@@ -7,7 +7,7 @@
import funsor.ops as ops
from funsor.cnf import Contraction, GaussianMixture
from funsor.delta import Delta
from funsor.gaussian import Gaussian, _mv, _trace_mm, _vv, align_gaussian, cholesky_inverse
from funsor.gaussian import Gaussian, align_gaussian, _mv, _trace_mm, _vv
from funsor.terms import (
Funsor,
FunsorMeta,
@@ -164,7 +164,7 @@ def eager_integrate(log_measure, integrand, reduced_vars):
# See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380.
# http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
norm = lhs.log_normalizer.data.exp()
lhs_cov = cholesky_inverse(lhs._precision_chol)
lhs_cov = ops.cholesky_inverse(lhs._precision_chol)
lhs_loc = lhs.info_vec.unsqueeze(-1).cholesky_solve(lhs._precision_chol).squeeze(-1)
vmv_term = _vv(lhs_loc, rhs_info_vec - 0.5 * _mv(rhs_precision, lhs_loc))
data = norm * (vmv_term - 0.5 * _trace_mm(rhs_precision, lhs_cov))
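
The Matrix Cookbook citation in this hunk boils down to the Gaussian quadratic-form expectation E[x^T A x] = tr(A Sigma) + m^T A m for x ~ N(m, Sigma): the trace term is what _trace_mm(rhs_precision, lhs_cov) contributes, and the vmv_term supplies the mean-dependent part. A small numpy Monte Carlo check of that identity, independent of funsor:

import numpy as np

rng = np.random.default_rng(0)
m = np.array([0.5, -1.0])
L = np.array([[1.0, 0.0], [0.4, 0.8]])
Sigma = L @ L.T
A = np.array([[2.0, 0.3], [0.3, 1.0]])   # symmetric

x = rng.multivariate_normal(m, Sigma, size=200_000)
mc_estimate = np.einsum('ni,ij,nj->n', x, A, x).mean()    # E[x^T A x] by sampling
closed_form = np.trace(A @ Sigma) + m @ A @ m
print(mc_estimate, closed_form)   # should agree up to Monte Carlo error (~1e-2)
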
6 changes: 3 additions & 3 deletions funsor/joint.py
@@ -14,7 +14,7 @@
from funsor.cnf import Contraction, GaussianMixture
from funsor.delta import Delta
from funsor.domains import bint
from funsor.gaussian import Gaussian, align_gaussian, cholesky, cholesky_inverse
from funsor.gaussian import Gaussian, align_gaussian
from funsor.ops import AssociativeOp
from funsor.terms import Funsor, Independent, Number, Reduce, Unary, eager, moment_matching, normalize
from funsor.torch import Tensor, align_tensor
@@ -106,7 +106,7 @@ def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gauss
old_loc = Tensor(gaussian.info_vec.unsqueeze(-1).cholesky_solve(gaussian._precision_chol).squeeze(-1),
int_inputs)
new_loc = (probs * old_loc).reduce(ops.add, approx_vars)
old_cov = Tensor(cholesky_inverse(gaussian._precision_chol), int_inputs)
old_cov = Tensor(ops.cholesky_inverse(gaussian._precision_chol), int_inputs)
diff = old_loc - new_loc
outers = Tensor(diff.data.unsqueeze(-1) * diff.data.unsqueeze(-2), diff.inputs)
new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) +
@@ -117,7 +117,7 @@ def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gauss
mask = (total.data == 0).to(total.data.dtype).unsqueeze(-1).unsqueeze(-1)
new_cov.data += mask * torch.eye(new_cov.data.size(-1))

new_precision = Tensor(cholesky_inverse(cholesky(new_cov.data)), new_cov.inputs)
new_precision = Tensor(ops.cholesky_inverse(ops.cholesky(new_cov.data)), new_cov.inputs)
new_info_vec = new_precision.data.matmul(new_loc.data.unsqueeze(-1)).squeeze(-1)
new_inputs = new_loc.inputs.copy()
new_inputs.update((k, d) for k, d in gaussian.inputs.items() if d.dtype == 'real')
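
The moment_matching hunk above converts the Gaussian's natural parameters (info_vec, precision) to moments before averaging: loc solves precision @ loc = info_vec via the precision Cholesky factor, and cov is ops.cholesky_inverse of that factor. A minimal numpy sketch of the same conversion, using plain numpy solves in place of the Cholesky routines:

import numpy as np

precision = np.array([[4.0, 1.0],
                      [1.0, 3.0]])
info_vec = np.array([2.0, -1.0])

precision_chol = np.linalg.cholesky(precision)          # ops.cholesky
cov = np.linalg.inv(precision_chol @ precision_chol.T)  # ops.cholesky_inverse
loc = np.linalg.solve(precision, info_vec)              # the cholesky_solve step

assert np.allclose(cov @ info_vec, loc)
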
152 changes: 84 additions & 68 deletions funsor/numpy.py
@@ -11,60 +11,6 @@
from funsor.terms import Binary, Funsor, FunsorMeta, Number, eager, substitute, to_data, to_funsor


def align_array(new_inputs, x):
r"""
Permute and expand an array to match desired ``new_inputs``.

:param OrderedDict new_inputs: A target set of inputs.
:param funsor.terms.Funsor x: A :class:`Array` s or
or :class:`~funsor.terms.Number` .
:return: a number or :class:`numpy.ndarray` that can be broadcast to other
array with inputs ``new_inputs``.
:rtype: tuple
"""
assert isinstance(new_inputs, OrderedDict)
assert isinstance(x, (Number, Array))
assert all(isinstance(d.dtype, int) for d in x.inputs.values())

data = x.data
if isinstance(x, Number):
return data

old_inputs = x.inputs
if old_inputs == new_inputs:
return data

# Permute squashed input dims.
x_keys = tuple(old_inputs)
data = np.transpose(data, (tuple(x_keys.index(k) for k in new_inputs if k in old_inputs) +
tuple(range(len(old_inputs), data.ndim))))

# Unsquash multivariate input dims by filling in ones.
data = np.reshape(data, tuple(old_inputs[k].dtype if k in old_inputs else 1 for k in new_inputs) +
x.output.shape)
return data


def align_arrays(*args):
r"""
Permute multiple arrays before applying a broadcasted op.

This is mainly useful for implementing eager funsor operations.

:param funsor.terms.Funsor \*args: Multiple :class:`Array` s and
:class:`~funsor.terms.Number` s.
:return: a pair ``(inputs, arrays)`` where arrayss are all
:class:`numpy.ndarray` s that can be broadcast together to a single data
with given ``inputs``.
:rtype: tuple
"""
inputs = OrderedDict()
for x in args:
inputs.update(x.inputs)
arrays = [align_array(inputs, x) for x in args]
return inputs, arrays


class ArrayMeta(FunsorMeta):
"""
Wrapper to fill in default args and convert between OrderedDict and tuple.
@@ -207,6 +153,11 @@ def eager_subs(self, subs):
return Array(data, inputs, self.dtype)


@ops.TensorOp.register(np.ndarray, (type(None), tuple, OrderedDict), str)
def _Tensor(x, inputs, dtype):
return Array(x, inputs, dtype)


@dispatch(np.ndarray)
def to_funsor(x):
return Array(x)
@@ -221,6 +172,69 @@ def to_funsor(x, output):
return result


def align_array(new_inputs, x, expand=False):
r"""
Permute and expand an array to match desired ``new_inputs``.

:param OrderedDict new_inputs: A target set of inputs.
:param funsor.terms.Funsor x: An :class:`Array`
or :class:`~funsor.terms.Number` .
:param bool expand: If False (default), set result size to 1 for any input
of ``x`` not in ``new_inputs``; if True expand to ``new_inputs`` size.
:return: a number or :class:`numpy.ndarray` that can be broadcast to other
array with inputs ``new_inputs``.
:rtype: tuple
"""
assert isinstance(new_inputs, OrderedDict)
assert isinstance(x, (Number, Array))
assert all(isinstance(d.dtype, int) for d in x.inputs.values())

data = x.data
if isinstance(x, Number):
return data

old_inputs = x.inputs
if old_inputs == new_inputs:
return data

# Permute squashed input dims.
x_keys = tuple(old_inputs)
data = np.transpose(data, (tuple(x_keys.index(k) for k in new_inputs if k in old_inputs) +
tuple(range(len(old_inputs), data.ndim))))

# Unsquash multivariate input dims by filling in ones.
data = np.reshape(data, tuple(old_inputs[k].dtype if k in old_inputs else 1 for k in new_inputs) +
x.output.shape)

# Optionally expand new dims.
if expand:
data = np.broadcast_to(data, tuple(d.dtype for d in new_inputs.values()) + x.output.shape)
return data


def align_arrays(*args, **kwargs):
r"""
Permute multiple arrays before applying a broadcasted op.

This is mainly useful for implementing eager funsor operations.

:param funsor.terms.Funsor \*args: Multiple :class:`Array` s and
:class:`~funsor.terms.Number` s.
:param bool expand: Whether to expand input tensors. Defaults to False.
:return: a pair ``(inputs, arrays)`` where arrays are all
:class:`numpy.ndarray` s that can be broadcast together to a single data
with given ``inputs``.
:rtype: tuple
"""
expand = kwargs.pop('expand', False)
assert not kwargs
inputs = OrderedDict()
for x in args:
inputs.update(x.inputs)
arrays = [align_array(inputs, x, expand=expand) for x in args]
return inputs, arrays


@to_data.register(Array)
def _to_data_array(x):
if x.inputs:
@@ -322,6 +336,10 @@ def _sigmoid(x):
ops.log1p.register(np.ndarray)(np.log1p)
ops.min.register(np.ndarray, np.ndarray)(np.minimum)
ops.max.register(np.ndarray, np.ndarray)(np.maximum)
ops.unsqueeze.register(np.ndarray, int)(np.expand_dims)
ops.expand.register(np.ndarray, tuple)(np.broadcast_to)
ops.permute.register(np.ndarray, tuple)(np.transpose)
ops.transpose.register(np.ndarray, int, int)(np.swapaxes)


# TODO: replace (int, float) by object
@@ -393,7 +411,15 @@ def _cholesky_inverse(x):
def _triangular_solve(x, y, upper, transpose):
from scipy.linalg import solve_triangular

return solve_triangular(x, y, trans=int(transpose), lower=not upper)
# TODO: remove this logic when using JAX
# work around the issue of scipy which does not support batched input
batch_shape = np.broadcast(x[..., 0, 0], y[..., 0, 0]).shape
xs = np.broadcast_to(x, batch_shape + x.shape[-2:]).reshape((-1,) + x.shape[-2:])
ys = np.broadcast_to(y, batch_shape + y.shape[-2:]).reshape((-1,) + y.shape[-2:])
ans = [solve_triangular(y, x, trans=int(transpose), lower=not upper)
for (x, y) in zip(xs, ys)]
ans = np.stack(ans)
return ans.reshape(batch_shape + ans.shape[-2:])


@ops.diagonal.register(np.ndarray, int, int)
@@ -416,16 +442,6 @@ def _new_eye(x, shape):
return np.broadcast_to(np.eye(shape[-1]), shape + (-1,))


@ops.unsqueeze.register(np.ndarray, int)
def _unsqueeze(x, dim):
return np.expand_dims(x, dim)


@ops.expand.register(np.ndarray, tuple)
def _expand(x, shape):
return np.broadcast_to(x, shape)


@ops.transpose.register(np.ndarray, int, int)
def _transpose(x, dim0, dim1):
return np.swapaxes(x, dim0, dim1)
@ops.new_arange.register(np.ndarray, int, int, int)
def _new_arange(x, start, stop, step):
return np.arange(start, stop, step)
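
The new _triangular_solve above works around scipy.linalg.solve_triangular accepting only a single (unbatched) matrix: the broadcast batch is flattened, each matrix/rhs pair is solved in a Python loop, and the results are stacked and reshaped back. A self-contained sketch of the same pattern (the helper name batched_triangular_solve is illustrative, not funsor's API):

import numpy as np
from scipy.linalg import solve_triangular


def batched_triangular_solve(a, b, lower=True):
    """Solve a @ x = b for batched triangular matrices a and right-hand sides b."""
    batch_shape = np.broadcast(a[..., 0, 0], b[..., 0, 0]).shape
    a_flat = np.broadcast_to(a, batch_shape + a.shape[-2:]).reshape((-1,) + a.shape[-2:])
    b_flat = np.broadcast_to(b, batch_shape + b.shape[-2:]).reshape((-1,) + b.shape[-2:])
    out = np.stack([solve_triangular(ai, bi, lower=lower)
                    for ai, bi in zip(a_flat, b_flat)])
    return out.reshape(batch_shape + out.shape[-2:])


a = np.tril(np.random.rand(5, 3, 3)) + 3 * np.eye(3)   # batch of lower-triangular matrices
b = np.random.rand(5, 3, 2)
x = batched_triangular_solve(a, b)
assert np.allclose(a @ x, b)
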
19 changes: 19 additions & 0 deletions funsor/ops.py
@@ -346,6 +346,11 @@ def new_eye(x, shape):
raise NotImplementedError


@Op
def new_arange(x, start, stop, step):
raise NotImplementedError


@Op
def unsqueeze(x, dim):
raise NotImplementedError
@@ -366,6 +371,20 @@ def transpose(x, dim0, dim1):
raise NotImplementedError


@Op
def permute(x, dims):
raise NotImplementedError


@Op
def TensorOp(x, inputs, dtype):
raise NotImplementedError


def Tensor(x, inputs=None, dtype="real"):
return TensorOp(x, inputs, dtype)


__all__ = [
'AddOp',
'AssociativeOp',
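
The ops.py hunk above adds generic new_arange and permute ops plus a TensorOp-backed ops.Tensor constructor, so backend-agnostic code can build a funsor tensor without importing funsor.torch or funsor.numpy directly. A hedged usage sketch, assuming the np.ndarray registrations added in funsor/numpy.py above (ops.TensorOp, ops.new_arange, ops.unsqueeze, ops.permute) are loaded at this commit:

import numpy as np

import funsor.numpy  # loads the np.ndarray registrations shown above
import funsor.ops as ops

data = np.zeros((2, 3))
x = ops.Tensor(data)                  # dispatches through ops.TensorOp to funsor.numpy.Array
idx = ops.new_arange(data, 0, 5, 1)   # backend-agnostic arange, prototyped on `data`
col = ops.unsqueeze(data, -1)         # np.expand_dims under the hood
perm = ops.permute(data, (1, 0))      # np.transpose under the hood
print(type(x).__name__, idx, col.shape, perm.shape)
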
5 changes: 3 additions & 2 deletions funsor/pyro/convert.py
@@ -29,10 +29,11 @@
from funsor.delta import Delta
from funsor.distributions import BernoulliLogits, MultivariateNormal, Normal
from funsor.domains import bint, reals
from funsor.gaussian import Gaussian, align_tensors, cholesky
from funsor.gaussian import Gaussian
from funsor.interpreter import gensym
from funsor.ops import cholesky
from funsor.terms import Funsor, Independent, Variable, eager
from funsor.torch import Tensor
from funsor.torch import Tensor, align_tensors

# Conversion functions use fixed names for Pyro batch dims, but
# accept an event_inputs tuple for custom event dim names.