diff --git a/funsor/affine.py b/funsor/affine.py index 0751f84b..1e07a152 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -61,7 +61,9 @@ def _(fn): @affine_inputs.register(Unary) def _(fn): - if fn.op in (ops.neg, ops.sum) or isinstance(fn.op, ops.ReshapeOp): + if fn.op in (ops.neg, ops.sum) or isinstance( + fn.op, (ops.ReshapeOp, ops.GetsliceOp) + ): return affine_inputs(fn.arg) return frozenset() diff --git a/funsor/domains.py b/funsor/domains.py index a9472e7a..59cf7fc5 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -10,6 +10,7 @@ from weakref import WeakValueDictionary import funsor.ops as ops +from funsor.ops.builtin import parse_ellipsis, parse_slice from funsor.util import broadcast_shape, get_backend, get_tracing_state, quote Domain = type @@ -331,6 +332,58 @@ def _find_domain_getitem(op, lhs_domain, rhs_domain): ) +@find_domain.register(ops.GetsliceOp) +def _find_domain_getslice(op, domain): + index = op.defaults["index"] + if isinstance(domain, ArrayType): + dtype = domain.dtype + shape = list(domain.shape) + left, right = parse_ellipsis(index) + + i = 0 + for part in left: + if part is None: + shape.insert(i, 1) + i += 1 + elif isinstance(part, int): + del shape[i] + elif isinstance(part, slice): + start, stop, step = parse_slice(part, shape[i]) + shape[i] = max(0, (stop - start + step - 1) // step) + i += 1 + else: + raise ValueError(part) + + i = -1 + for part in reversed(right): + if part is None: + shape.insert(len(shape) + i + 1, 1) + i -= 1 + elif isinstance(part, int): + del shape[i] + elif isinstance(part, slice): + start, stop, step = parse_slice(part, shape[i]) + shape[i] = max(0, (stop - start + step - 1) // step) + i -= 1 + else: + raise ValueError(part) + + return Array[dtype, tuple(shape)] + + if isinstance(domain, ProductDomain): + if isinstance(index, tuple): + assert len(index) == 1 + index = index[0] + if isinstance(index, int): + return domain.__args__[index] + elif isinstance(index, slice): + return Product[domain.__args__[index]] + else: + raise ValueError(index) + + raise NotImplementedError("TODO") + + @find_domain.register(ops.BinaryOp) def _find_domain_pointwise_binary_generic(op, lhs, rhs): if ( diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py index 103cf3e1..5fee7e8a 100644 --- a/funsor/ops/builtin.py +++ b/funsor/ops/builtin.py @@ -13,6 +13,7 @@ UNITS, BinaryOp, Op, + OpMeta, TransformOp, UnaryOp, declare_op_types, @@ -43,6 +44,105 @@ def getitem(lhs, rhs, offset=0): return lhs[(slice(None),) * offset + (rhs,)] +class GetsliceMeta(OpMeta): + """ + Works around slice objects not being hashable. + """ + + def hash_args_kwargs(cls, args, kwargs): + index = args[0] if args else kwargs["index"] + if not isinstance(index, tuple): + index = (index,) + key = tuple( + (x.start, x.stop, x.step) if isinstance(x, slice) else x for x in index + ) + return key + + +@UnaryOp.make(metaclass=GetsliceMeta) +def getslice(x, index=Ellipsis): + return x[index] + + +getslice.supported_types = (type(None), type(Ellipsis), int, slice) + + +def parse_ellipsis(index): + """ + Helper to split a slice into parts left and right of Ellipses. + + :param index: A tuple, or other object (None, int, slice, Funsor). + :returns: a pair of tuples ``left, right``. + :rtype: tuple + """ + if not isinstance(index, tuple): + index = (index,) + left = [] + i = 0 + for part in index: + i += 1 + if part is Ellipsis: + break + left.append(part) + right = [] + for part in reversed(index[i:]): + if part is Ellipsis: + break + right.append(part) + right.reverse() + return tuple(left), tuple(right) + + +def normalize_ellipsis(index, size): + """ + Expand Ellipses in an index to fill the given number of dimensions. + + This should satisfy the equation:: + + x[i] == x[normalize_ellipsis(i, len(x.shape))] + """ + left, right = parse_ellipsis(index) + if len(left) + len(right) > size: + raise ValueError(f"Index is too wide: {index}") + middle = (slice(None),) * (size - len(left) - len(right)) + return left + middle + right + + +def parse_slice(s, size): + """ + Helper to determine nonnegative integers (start, stop, step) of a slice. + + :param slice s: A slice. + :param int size: The size of the array being indexed into. + :returns: A tuple of nonnegative integers ``start, stop, step``. + :rtype: tuple + """ + start = s.start + if start is None: + start = 0 + assert isinstance(start, int) + if start >= 0: + start = min(size, start) + else: + start = max(0, size + start) + + stop = s.stop + if stop is None: + stop = size + assert isinstance(stop, int) + if stop >= 0: + stop = min(size, stop) + else: + stop = max(0, size + stop) + + step = s.step + if step is None: + step = 1 + assert isinstance(step, int) + + return start, stop, step + + abs = UnaryOp.make(_builtin_abs) eq = BinaryOp.make(operator.eq) ge = BinaryOp.make(operator.ge) @@ -194,6 +294,7 @@ def sigmoid_log_abs_det_jacobian(x, y): "floordiv", "ge", "getitem", + "getslice", "gt", "invert", "le", diff --git a/funsor/tensor.py b/funsor/tensor.py index 204c0aaf..a2bac8ff 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -859,6 +859,16 @@ def eager_getitem_tensor_tensor(op, lhs, rhs): return Tensor(data, inputs, lhs.dtype) +@eager.register(Unary, ops.GetsliceOp, Tensor) +def eager_getslice_tensor(op, x): + index = op.defaults["index"] + if not isinstance(index, tuple): + index = (index,) + index = (slice(None),) * len(x.inputs) + index + data = x.data[index] + return Tensor(data, x.inputs, x.dtype) + + @eager.register( Finitary, ops.StackOp, typing.Tuple[typing.Union[(Number, Tensor)], ...] ) diff --git a/funsor/terms.py b/funsor/terms.py index 354921ba..021c3187 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -17,6 +17,7 @@ from funsor.domains import ( Array, Bint, + BintType, Domain, Product, ProductDomain, @@ -34,6 +35,7 @@ ) from funsor.interpreter import PatternMissingError, interpret from funsor.ops import AssociativeOp, GetitemOp, Op +from funsor.ops.builtin import normalize_ellipsis, parse_ellipsis, parse_slice from funsor.syntax import INFIX_OPERATORS, PREFIX_OPERATORS from funsor.typing import GenericTypeMeta, Variadic, deep_type, get_args, get_origin from funsor.util import getargspec, lazy_property, pretty, quote @@ -730,23 +732,24 @@ def __ge__(self, other): return Binary(ops.ge, self, to_funsor(other)) def __getitem__(self, other): + """ + Helper to desugar into either ops.getitem (for advanced indexing + involving Funsors as indices) or ops.getslice (for simple indexing + involving only integers, slices, None, and Ellipsis). + """ if type(other) is not tuple: + if isinstance(other, ops.getslice.supported_types): + return ops.getslice(self, other) other = to_funsor(other, Bint[self.output.shape[0]]) return Binary(ops.getitem, self, other) + # Handle complex slicing operations involving no funsors. + if all(isinstance(part, ops.getslice.supported_types) for part in other): + return ops.getslice(self, other) + # Handle Ellipsis slicing. if any(part is Ellipsis for part in other): - left = [] - for part in other: - if part is Ellipsis: - break - left.append(part) - right = [] - for part in reversed(other): - if part is Ellipsis: - break - right.append(part) - right.reverse() + left, right = parse_ellipsis(other) missing = len(self.output.shape) - len(left) - len(right) assert missing >= 0 middle = [slice(None)] * missing @@ -756,6 +759,8 @@ def __getitem__(self, other): result = self offset = 0 for part in other: + if part is None: + raise NotImplementedError("TODO") if isinstance(part, slice): if part != slice(None): raise NotImplementedError("TODO support nontrivial slicing") @@ -1474,6 +1479,15 @@ def eager_subs(self, subs): ) +@to_funsor.register(slice) +def slice_to_funsor(s, output=None, dim_to_name=None): + if not isinstance(output, BintType): + raise ValueError("Incompatible slice output: {output}") + start, stop, step = parse_slice(s, output.size) + i = Variable("slice", output) + return Lambda(i, Slice("slice", start, stop, step, output.size)) + + class Align(Funsor): """ Lazy call to ``.align(...)``. @@ -1571,20 +1585,21 @@ def eager_subs(self, subs): index = subs[0][1] # Try to eagerly select an index. - assert index.output == Bint[len(self.parts)] - - if isinstance(index, Number): - # Select a single part. - return self.parts[index.data] - elif isinstance(index, Variable): - # Rename the stacking dimension. - parts = self.parts - return Stack(index.name, parts) - elif isinstance(index, Slice): - parts = self.parts[index.slice] - return Stack(index.name, parts) + if index.output == Bint[len(self.parts)]: + if isinstance(index, Number): + # Select a single part. + return self.parts[index.data] + elif isinstance(index, Variable): + # Rename the stacking dimension. + parts = self.parts + return Stack(index.name, parts) + elif isinstance(index, Slice): + parts = self.parts[index.slice] + return Stack(index.name, parts) + else: + raise NotImplementedError("TODO support advanced indexing in Stack") else: - raise NotImplementedError("TODO support advanced indexing in Stack") + raise NotImplementedError("TODO support slicing in Stack") def eager_reduce(self, op, reduced_vars): parts = self.parts @@ -1754,6 +1769,21 @@ def eager_getitem_lambda(op, lhs, rhs): return Lambda(lhs.var, expr) +@eager.register(Unary, ops.GetsliceOp, Lambda) +def eager_getslice_lambda(op, x): + index = normalize_ellipsis(op.defaults["index"], len(x.shape)) + head, tail = index[0], index[1:] + expr = x.expr + if head != slice(None): + expr = expr(**{x.var.name: head}) + if tail: + expr = ops.getslice(expr, tail) + if x.var.name in expr.inputs: # dim is preserved, e.g. x[1:] + return Lambda(x.var, expr) + else: # dim is eliminated, e.g. x[0] + return expr + + class Independent(Funsor): """ Creates an independent diagonal distribution. @@ -1885,6 +1915,21 @@ def eager_getitem_tuple(op, lhs, rhs): return op(lhs.args, rhs.data) +@lazy.register(Unary, ops.GetsliceOp, Tuple) +@eager.register(Unary, ops.GetsliceOp, Tuple) +def eager_getslice_tuple(op, x): + index = op.defaults["index"] + if isinstance(index, tuple): + assert len(index) == 1 + index = index[0] + if isinstance(index, int): + return op(x.args) + elif isinstance(index, slice): + return Tuple(op(x.args)) + else: + raise ValueError(index) + + def _symbolic(inputs, output, fn): args, vargs, kwargs, defaults = getargspec(fn) assert not vargs diff --git a/funsor/testing.py b/funsor/testing.py index c1c39e2d..acedea8f 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -478,3 +478,20 @@ def iter_subsets(iterable, *, min_size=None, max_size=None): max_size = len(iterable) for size in range(min_size, max_size + 1): yield from itertools.combinations(iterable, size) + + +class DesugarGetitem: + """ + Helper to desugar ``.__getitem__()`` syntax. + + Example:: + + >>> desugar_getitem[1:3, ..., None] + (slice(1, 3), Ellipsis, None) + """ + + def __getitem__(self, index): + return index + + +desugar_getitem = DesugarGetitem() diff --git a/test/test_ops.py b/test/test_ops.py index f70b8d80..f9b2ca4a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -7,6 +7,8 @@ from funsor import ops from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND +from funsor.ops.builtin import parse_ellipsis, parse_slice +from funsor.testing import desugar_getitem from funsor.util import get_backend @@ -36,3 +38,39 @@ def test_transform_op_gc(dist): assert len(op_set) == 1 del op assert len(op_set) == 0 + + +@pytest.mark.parametrize( + "index, left, right", + [ + (desugar_getitem[()], (), ()), + (desugar_getitem[0], (0,), ()), + (desugar_getitem[...], (), ()), + (desugar_getitem[..., ...], (), ()), + (desugar_getitem[1, ...], (1,), ()), + (desugar_getitem[..., 1], (), (1,)), + (desugar_getitem[:, None, ..., 1, 1:2], (slice(None), None), (1, slice(1, 2))), + ], + ids=str, +) +def test_parse_ellipsis(index, left, right): + assert parse_ellipsis(index) == (left, right) + + +@pytest.mark.parametrize( + "s, size, start, stop, step", + [ + (desugar_getitem[:], 5, 0, 5, 1), + (desugar_getitem[:3], 5, 0, 3, 1), + (desugar_getitem[-9:3], 5, 0, 3, 1), + (desugar_getitem[:-2], 5, 0, 3, 1), + (desugar_getitem[2:], 5, 2, 5, 1), + (desugar_getitem[2:9], 5, 2, 5, 1), + (desugar_getitem[-3:], 5, 2, 5, 1), + (desugar_getitem[-3:-2], 5, 2, 3, 1), + ], + ids=str, +) +def test_parse_slice(s, size, start, stop, step): + actual = parse_slice(s, size) + assert actual == (start, stop, step) diff --git a/test/test_tensor.py b/test/test_tensor.py index daee33f2..1a9f41e9 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -31,6 +31,7 @@ assert_close, assert_equiv, check_funsor, + desugar_getitem, empty, iter_subsets, rand, @@ -608,6 +609,60 @@ def test_lambda_getitem(): assert Lambda(i, y) is x +@pytest.mark.parametrize( + "index", + [ + desugar_getitem[0], + desugar_getitem[1, 2], + desugar_getitem[None], + desugar_getitem[None, 1], + desugar_getitem[2, None], + desugar_getitem[None, 0, None], + desugar_getitem[:], + desugar_getitem[:, :], + desugar_getitem[1:], + desugar_getitem[1:3], + desugar_getitem[::2], + desugar_getitem[1::2], + desugar_getitem[:, None], + desugar_getitem[None, :], + desugar_getitem[None, :, 1], + desugar_getitem[...], + desugar_getitem[..., 0], + desugar_getitem[..., 0, 1], + desugar_getitem[..., 0, :], + desugar_getitem[..., None, :], + desugar_getitem[..., 1:-1:2, :], + desugar_getitem[:, 0, ...], + desugar_getitem[:, None, ...], + desugar_getitem[:, 1:-1:2, ...], + desugar_getitem[None, ..., None], + desugar_getitem[:, None, ..., :, None], + desugar_getitem[:, None, ..., None, :], + desugar_getitem[None, :, ..., :, None], + desugar_getitem[None, :, ..., None, :], + desugar_getitem[0, None, ..., 0, None], + desugar_getitem[0, None, ..., None, 0], + desugar_getitem[None, 0, ..., 0, None], + desugar_getitem[None, 0, ..., None, 0], + ], + ids=str, +) +def test_getslice_shape(index): + shape = (6, 5, 4, 3) + data = randn(shape) + expected = Tensor(data[index]) + + # Check eager indexing. + actual = Tensor(data)[index] + assert_close(actual, expected) + + # Check lazy find_domain. + actual = Variable("x", Reals[shape])[index] + assert actual.dtype == expected.dtype + assert actual.shape == expected.shape + + REDUCE_OPS = [ ops.add, ops.mul, diff --git a/test/test_terms.py b/test/test_terms.py index 720bd571..5f99e553 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -629,6 +629,35 @@ def test_stack_lambda(dtype): assert z[1] is x2 +@pytest.mark.parametrize("dtype", ["real", 4, 5]) +def test_stack_lambda_2(dtype): + + x1 = Number(0, dtype) + x2 = Number(1, dtype) + x3 = Number(2, dtype) + x4 = Number(3, dtype) + x = [[x1, x2, x3], [x2, x3, x4]] + + i = Variable("i", Bint[3]) + y1 = Lambda(i, Stack("i", (x1, x2, x3))) + y2 = Lambda(i, Stack("i", (x2, x3, x4))) + + j = Variable("j", Bint[2]) + z = Lambda(j, Stack("j", (y1, y2))) + assert not z.inputs + assert z.output == Array[dtype, (2, 3)] + + assert z[0] is y1 + assert z[1] is y2 + for i, j in itertools.product(range(2), range(3)): + assert z[i, j] is x[i][j] + assert z[:, j][i] is x[i][j] + assert z[i, :][j] is x[i][j] + # TODO support advanced slicing of Stack + # assert z[0:9, j][i] is x[i][j] + # assert z[i, 0:9][j] is x[i][j] + + def test_funsor_tuple(): x = Number(1, 3) y = Number(2.5, "real") @@ -643,6 +672,10 @@ def test_funsor_tuple(): assert xyz[0] is x assert xyz[1] is y assert xyz[2] is z + assert xyz[:] is xyz + assert xyz[1:] is Tuple((y, z)) + assert xyz[:2] is Tuple((x, y)) + assert xyz[::2] is Tuple((x, z)) x1, y1, z1 = xyz assert x1 is x