From 6b2db4dda0d29b86342d788d52085ee4b9df882e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 24 Sep 2021 20:06:54 -0400 Subject: [PATCH 1/4] Add an ops.getslice for more complex eager indexing --- funsor/domains.py | 39 ++++++++++++++++++++ funsor/ops/builtin.py | 84 +++++++++++++++++++++++++++++++++++++++++++ funsor/tensor.py | 10 ++++++ funsor/terms.py | 21 ++++++----- test/test_tensor.py | 45 +++++++++++++++++++++++ 5 files changed, 188 insertions(+), 11 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index a9472e7a4..84dee368c 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -10,6 +10,7 @@ from weakref import WeakValueDictionary import funsor.ops as ops +from funsor.ops.builtin import parse_ellipsis, parse_slice from funsor.util import broadcast_shape, get_backend, get_tracing_state, quote Domain = type @@ -331,6 +332,44 @@ def _find_domain_getitem(op, lhs_domain, rhs_domain): ) +@find_domain.register(ops.GetsliceOp) +def _find_domain_getslice(op, domain): + index = op.defaults["index"] + left, right = parse_ellipsis(index) + if isinstance(domain, ArrayType): + dtype = domain.dtype + shape = list(domain.shape) + + offset = len(shape) + for i, part in enumerate(left): + i -= offset + if part is None: + shape.insert(i, 1) + elif isinstance(part, int): + del shape[i] + elif isinstance(part, slice): + start, stop, step = parse_slice(part, shape[i]) + shape[i] = max(0, (stop - start) // step) + else: + raise ValueError(part) + + for i in range(-len(right), 0): + part = right[i] + if part is None: + shape.insert(len(shape) + i, 1) + elif isinstance(part, int): + del shape[i] + elif isinstance(part, slice): + start, stop, step = parse_slice(part, shape[i]) + shape[i] = max(0, (stop - start) // step) + else: + raise ValueError(part) + + return Array[dtype, tuple(shape)] + + raise NotImplementedError("TODO") + + @find_domain.register(ops.BinaryOp) def _find_domain_pointwise_binary_generic(op, lhs, rhs): if ( diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py index 103cf3e1f..ccd1c327d 100644 --- a/funsor/ops/builtin.py +++ b/funsor/ops/builtin.py @@ -13,6 +13,7 @@ UNITS, BinaryOp, Op, + OpMeta, TransformOp, UnaryOp, declare_op_types, @@ -43,6 +44,88 @@ def getitem(lhs, rhs, offset=0): return lhs[(slice(None),) * offset + (rhs,)] +class GetsliceMeta(OpMeta): + """ + Works around slice objects not being hashable. + """ + + def hash_args_kwargs(cls, args, kwargs): + index = args[0] if args else kwargs["index"] + if not isinstance(index, tuple): + index = (index,) + key = tuple( + (x.start, x.stop, x.step) if isinstance(x, slice) else x for x in index + ) + return key + + +@UnaryOp.make(metaclass=GetsliceMeta) +def getslice(x, index=Ellipsis): + return x[index] + + +getslice.supported_types = (type(None), type(Ellipsis), int, slice) + + +def parse_ellipsis(index): + """ + Helper to split a slice into parts left and right of Ellipses. + + :param index: A tuple, or other object (None, int, slice, Funsor). + :returns: a pair of tuples ``left, right``. + :rtype: tuple + """ + if not isinstance(index, tuple): + index = (index,) + left = [] + for i, part in enumerate(index): + if part is Ellipsis: + break + left.append(part) + right = [] + for part in reversed(index[i:]): + if part is Ellipsis: + break + right.append(part) + right.reverse() + return tuple(left), tuple(right) + + +def parse_slice(s, size): + """ + Helper to determine nonnegative integer start, stop, and step of a slice. + + :param slice s: A slice. + :param int size: The size of the array being indexed into. + :returns: A tuple of nonnegative integers ``start, stop, step``. + :rtype: tuple + """ + start = s.start + if start is None: + start = 0 + assert isinstance(start, int) + if start >= 0: + start = min(size, start) + else: + start = max(0, size + start) + + stop = s.stop + if stop is None: + stop = size + assert isinstance(stop, int) + if stop >= 0: + stop = min(size, stop) + else: + stop = max(0, size + stop) + + step = s.step + if step is None: + step = 1 + assert isinstance(step, int) + + return start, stop, step + + abs = UnaryOp.make(_builtin_abs) eq = BinaryOp.make(operator.eq) ge = BinaryOp.make(operator.ge) @@ -194,6 +277,7 @@ def sigmoid_log_abs_det_jacobian(x, y): "floordiv", "ge", "getitem", + "getslice", "gt", "invert", "le", diff --git a/funsor/tensor.py b/funsor/tensor.py index 204c0aaf5..9a6d1bd68 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -859,6 +859,16 @@ def eager_getitem_tensor_tensor(op, lhs, rhs): return Tensor(data, inputs, lhs.dtype) +@eager.register(Binary, ops.GetsliceOp, Tensor) +def eager_getslice_tensor(op, x): + index = op.defaults["index"] + if not isinstance(index, tuple): + index = (index,) + index = (slice(None),) * len(x.inputs) + index + data = x.data[index] + return Tensor(data, x.inputs, x.dtype) + + @eager.register( Finitary, ops.StackOp, typing.Tuple[typing.Union[(Number, Tensor)], ...] ) diff --git a/funsor/terms.py b/funsor/terms.py index 354921bae..e8aa53f96 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -34,6 +34,7 @@ ) from funsor.interpreter import PatternMissingError, interpret from funsor.ops import AssociativeOp, GetitemOp, Op +from funsor.ops.builtin import parse_ellipsis from funsor.syntax import INFIX_OPERATORS, PREFIX_OPERATORS from funsor.typing import GenericTypeMeta, Variadic, deep_type, get_args, get_origin from funsor.util import getargspec, lazy_property, pretty, quote @@ -731,22 +732,18 @@ def __ge__(self, other): def __getitem__(self, other): if type(other) is not tuple: + if isinstance(other, ops.getslice.supported_types): + return ops.getslice(self, other) other = to_funsor(other, Bint[self.output.shape[0]]) return Binary(ops.getitem, self, other) + # Handle complex slicing operations involving no funsors. + if all(isinstance(part, ops.getslice.supported_types) for part in other): + return ops.getslice(self, other) + # Handle Ellipsis slicing. if any(part is Ellipsis for part in other): - left = [] - for part in other: - if part is Ellipsis: - break - left.append(part) - right = [] - for part in reversed(other): - if part is Ellipsis: - break - right.append(part) - right.reverse() + left, right = parse_ellipsis(other) missing = len(self.output.shape) - len(left) - len(right) assert missing >= 0 middle = [slice(None)] * missing @@ -756,6 +753,8 @@ def __getitem__(self, other): result = self offset = 0 for part in other: + if part is None: + raise NotImplementedError("TODO") if isinstance(part, slice): if part != slice(None): raise NotImplementedError("TODO support nontrivial slicing") diff --git a/test/test_tensor.py b/test/test_tensor.py index daee33f21..62ae311ed 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -608,6 +608,51 @@ def test_lambda_getitem(): assert Lambda(i, y) is x +class _Slice: + def __getitem__(self, index): + return index + + +_slice = _Slice() + + +@pytest.mark.parametrize( + "index", + [ + _slice[0], + _slice[1, 2], + _slice[None], + _slice[None, 1], + _slice[2, None], + _slice[None, 0, None], + _slice[:], + _slice[:, :], + _slice[1:], + _slice[1:3], + _slice[::2], + _slice[1::2], + _slice[:, None], + _slice[None, :], + _slice[None, :, 1], + _slice[...], + _slice[..., 0], + _slice[..., 0, 1], + _slice[..., 0, :], + _slice[..., None, :], + _slice[..., 1:-1:2, :], + _slice[:, 0, ...], + _slice[:, None, ...], + _slice[:, 1:-1:2, ...], + ], + ids=str, +) +def test_getslice_shape(index): + data = randn(6, 5, 4, 3) + expected = Tensor(data[index]) + actual = Tensor(data)[index] + assert_close(actual, expected) + + REDUCE_OPS = [ ops.add, ops.mul, From 6f7b32636503a6520b59a203c57518eb714c16fe Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 25 Sep 2021 10:04:22 -0400 Subject: [PATCH 2/4] Fix bugs, add patterns, add more tests --- funsor/domains.py | 32 +++++++++++++----- funsor/ops/builtin.py | 21 ++++++++++-- funsor/tensor.py | 2 +- funsor/terms.py | 36 +++++++++++++++++++- funsor/testing.py | 17 ++++++++++ test/test_ops.py | 38 ++++++++++++++++++++++ test/test_tensor.py | 76 ++++++++++++++++++++++++------------------- test/test_terms.py | 4 +++ 8 files changed, 180 insertions(+), 46 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 84dee368c..59cf7fc52 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -335,38 +335,52 @@ def _find_domain_getitem(op, lhs_domain, rhs_domain): @find_domain.register(ops.GetsliceOp) def _find_domain_getslice(op, domain): index = op.defaults["index"] - left, right = parse_ellipsis(index) if isinstance(domain, ArrayType): dtype = domain.dtype shape = list(domain.shape) + left, right = parse_ellipsis(index) - offset = len(shape) - for i, part in enumerate(left): - i -= offset + i = 0 + for part in left: if part is None: shape.insert(i, 1) + i += 1 elif isinstance(part, int): del shape[i] elif isinstance(part, slice): start, stop, step = parse_slice(part, shape[i]) - shape[i] = max(0, (stop - start) // step) + shape[i] = max(0, (stop - start + step - 1) // step) + i += 1 else: raise ValueError(part) - for i in range(-len(right), 0): - part = right[i] + i = -1 + for part in reversed(right): if part is None: - shape.insert(len(shape) + i, 1) + shape.insert(len(shape) + i + 1, 1) + i -= 1 elif isinstance(part, int): del shape[i] elif isinstance(part, slice): start, stop, step = parse_slice(part, shape[i]) - shape[i] = max(0, (stop - start) // step) + shape[i] = max(0, (stop - start + step - 1) // step) + i -= 1 else: raise ValueError(part) return Array[dtype, tuple(shape)] + if isinstance(domain, ProductDomain): + if isinstance(index, tuple): + assert len(index) == 1 + index = index[0] + if isinstance(index, int): + return domain.__args__[index] + elif isinstance(index, slice): + return Product[domain.__args__[index]] + else: + raise ValueError(index) + raise NotImplementedError("TODO") diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py index ccd1c327d..5fee7e8ae 100644 --- a/funsor/ops/builtin.py +++ b/funsor/ops/builtin.py @@ -78,7 +78,9 @@ def parse_ellipsis(index): if not isinstance(index, tuple): index = (index,) left = [] - for i, part in enumerate(index): + i = 0 + for part in index: + i += 1 if part is Ellipsis: break left.append(part) @@ -91,9 +93,24 @@ def parse_ellipsis(index): return tuple(left), tuple(right) +def normalize_ellipsis(index, size): + """ + Expand Ellipses in an index to fill the given number of dimensions. + + This should satisfy the equation:: + + x[i] == x[normalize_ellipsis(i, len(x.shape))] + """ + left, right = parse_ellipsis(index) + if len(left) + len(right) > size: + raise ValueError(f"Index is too wide: {index}") + middle = (slice(None),) * (size - len(left) - len(right)) + return left + middle + right + + def parse_slice(s, size): """ - Helper to determine nonnegative integer start, stop, and step of a slice. + Helper to determine nonnegative integers (start, stop, step) of a slice. :param slice s: A slice. :param int size: The size of the array being indexed into. diff --git a/funsor/tensor.py b/funsor/tensor.py index 9a6d1bd68..a2bac8ff2 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -859,7 +859,7 @@ def eager_getitem_tensor_tensor(op, lhs, rhs): return Tensor(data, inputs, lhs.dtype) -@eager.register(Binary, ops.GetsliceOp, Tensor) +@eager.register(Unary, ops.GetsliceOp, Tensor) def eager_getslice_tensor(op, x): index = op.defaults["index"] if not isinstance(index, tuple): diff --git a/funsor/terms.py b/funsor/terms.py index e8aa53f96..707601a0d 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -34,7 +34,7 @@ ) from funsor.interpreter import PatternMissingError, interpret from funsor.ops import AssociativeOp, GetitemOp, Op -from funsor.ops.builtin import parse_ellipsis +from funsor.ops.builtin import normalize_ellipsis, parse_ellipsis from funsor.syntax import INFIX_OPERATORS, PREFIX_OPERATORS from funsor.typing import GenericTypeMeta, Variadic, deep_type, get_args, get_origin from funsor.util import getargspec, lazy_property, pretty, quote @@ -731,6 +731,11 @@ def __ge__(self, other): return Binary(ops.ge, self, to_funsor(other)) def __getitem__(self, other): + """ + Helper to desugar into either ops.getitem (for advanced indexing + involving Funsors as indices) or ops.getslice (for simple indexing + involving only integers, slices, None, and Ellipsis). + """ if type(other) is not tuple: if isinstance(other, ops.getslice.supported_types): return ops.getslice(self, other) @@ -1753,6 +1758,20 @@ def eager_getitem_lambda(op, lhs, rhs): return Lambda(lhs.var, expr) +@eager.register(Unary, ops.GetsliceOp, Lambda) +def eager_getslice_lambda(op, x): + index = normalize_ellipsis(op.defaults["index"], len(x.shape)) + head, tail = index[0], index[1:] + expr = x.expr + if head != slice(None): + expr = expr(**{x.var.name: head}) + if x.var.name not in expr.inputs: + return expr + if tail: + expr = ops.getslice(expr, tail) + return Lambda(x.var, expr) + + class Independent(Funsor): """ Creates an independent diagonal distribution. @@ -1884,6 +1903,21 @@ def eager_getitem_tuple(op, lhs, rhs): return op(lhs.args, rhs.data) +@lazy.register(Unary, ops.GetsliceOp, Tuple) +@eager.register(Unary, ops.GetsliceOp, Tuple) +def eager_getslice_tuple(op, x): + index = op.defaults["index"] + if isinstance(index, tuple): + assert len(index) == 1 + index = index[0] + if isinstance(index, int): + return op(x.args) + elif isinstance(index, slice): + return Tuple(op(x.args)) + else: + raise ValueError(index) + + def _symbolic(inputs, output, fn): args, vargs, kwargs, defaults = getargspec(fn) assert not vargs diff --git a/funsor/testing.py b/funsor/testing.py index c1c39e2db..acedea8f4 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -478,3 +478,20 @@ def iter_subsets(iterable, *, min_size=None, max_size=None): max_size = len(iterable) for size in range(min_size, max_size + 1): yield from itertools.combinations(iterable, size) + + +class DesugarGetitem: + """ + Helper to desugar ``.__getitem__()`` syntax. + + Example:: + + >>> desugar_getitem[1:3, ..., None] + (slice(1, 3), Ellipsis, None) + """ + + def __getitem__(self, index): + return index + + +desugar_getitem = DesugarGetitem() diff --git a/test/test_ops.py b/test/test_ops.py index f70b8d804..f9b2ca4a2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -7,6 +7,8 @@ from funsor import ops from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND +from funsor.ops.builtin import parse_ellipsis, parse_slice +from funsor.testing import desugar_getitem from funsor.util import get_backend @@ -36,3 +38,39 @@ def test_transform_op_gc(dist): assert len(op_set) == 1 del op assert len(op_set) == 0 + + +@pytest.mark.parametrize( + "index, left, right", + [ + (desugar_getitem[()], (), ()), + (desugar_getitem[0], (0,), ()), + (desugar_getitem[...], (), ()), + (desugar_getitem[..., ...], (), ()), + (desugar_getitem[1, ...], (1,), ()), + (desugar_getitem[..., 1], (), (1,)), + (desugar_getitem[:, None, ..., 1, 1:2], (slice(None), None), (1, slice(1, 2))), + ], + ids=str, +) +def test_parse_ellipsis(index, left, right): + assert parse_ellipsis(index) == (left, right) + + +@pytest.mark.parametrize( + "s, size, start, stop, step", + [ + (desugar_getitem[:], 5, 0, 5, 1), + (desugar_getitem[:3], 5, 0, 3, 1), + (desugar_getitem[-9:3], 5, 0, 3, 1), + (desugar_getitem[:-2], 5, 0, 3, 1), + (desugar_getitem[2:], 5, 2, 5, 1), + (desugar_getitem[2:9], 5, 2, 5, 1), + (desugar_getitem[-3:], 5, 2, 5, 1), + (desugar_getitem[-3:-2], 5, 2, 3, 1), + ], + ids=str, +) +def test_parse_slice(s, size, start, stop, step): + actual = parse_slice(s, size) + assert actual == (start, stop, step) diff --git a/test/test_tensor.py b/test/test_tensor.py index 62ae311ed..1a9f41e94 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -31,6 +31,7 @@ assert_close, assert_equiv, check_funsor, + desugar_getitem, empty, iter_subsets, rand, @@ -608,50 +609,59 @@ def test_lambda_getitem(): assert Lambda(i, y) is x -class _Slice: - def __getitem__(self, index): - return index - - -_slice = _Slice() - - @pytest.mark.parametrize( "index", [ - _slice[0], - _slice[1, 2], - _slice[None], - _slice[None, 1], - _slice[2, None], - _slice[None, 0, None], - _slice[:], - _slice[:, :], - _slice[1:], - _slice[1:3], - _slice[::2], - _slice[1::2], - _slice[:, None], - _slice[None, :], - _slice[None, :, 1], - _slice[...], - _slice[..., 0], - _slice[..., 0, 1], - _slice[..., 0, :], - _slice[..., None, :], - _slice[..., 1:-1:2, :], - _slice[:, 0, ...], - _slice[:, None, ...], - _slice[:, 1:-1:2, ...], + desugar_getitem[0], + desugar_getitem[1, 2], + desugar_getitem[None], + desugar_getitem[None, 1], + desugar_getitem[2, None], + desugar_getitem[None, 0, None], + desugar_getitem[:], + desugar_getitem[:, :], + desugar_getitem[1:], + desugar_getitem[1:3], + desugar_getitem[::2], + desugar_getitem[1::2], + desugar_getitem[:, None], + desugar_getitem[None, :], + desugar_getitem[None, :, 1], + desugar_getitem[...], + desugar_getitem[..., 0], + desugar_getitem[..., 0, 1], + desugar_getitem[..., 0, :], + desugar_getitem[..., None, :], + desugar_getitem[..., 1:-1:2, :], + desugar_getitem[:, 0, ...], + desugar_getitem[:, None, ...], + desugar_getitem[:, 1:-1:2, ...], + desugar_getitem[None, ..., None], + desugar_getitem[:, None, ..., :, None], + desugar_getitem[:, None, ..., None, :], + desugar_getitem[None, :, ..., :, None], + desugar_getitem[None, :, ..., None, :], + desugar_getitem[0, None, ..., 0, None], + desugar_getitem[0, None, ..., None, 0], + desugar_getitem[None, 0, ..., 0, None], + desugar_getitem[None, 0, ..., None, 0], ], ids=str, ) def test_getslice_shape(index): - data = randn(6, 5, 4, 3) + shape = (6, 5, 4, 3) + data = randn(shape) expected = Tensor(data[index]) + + # Check eager indexing. actual = Tensor(data)[index] assert_close(actual, expected) + # Check lazy find_domain. + actual = Variable("x", Reals[shape])[index] + assert actual.dtype == expected.dtype + assert actual.shape == expected.shape + REDUCE_OPS = [ ops.add, diff --git a/test/test_terms.py b/test/test_terms.py index 720bd5715..63d75c67c 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -643,6 +643,10 @@ def test_funsor_tuple(): assert xyz[0] is x assert xyz[1] is y assert xyz[2] is z + assert xyz[:] is xyz + assert xyz[1:] is Tuple((y, z)) + assert xyz[:2] is Tuple((x, y)) + assert xyz[::2] is Tuple((x, z)) x1, y1, z1 = xyz assert x1 is x From 33d382b3b9a364073c797871c13927cf6d19e7a2 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 25 Sep 2021 10:23:45 -0400 Subject: [PATCH 3/4] fix is_affine() --- funsor/affine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/funsor/affine.py b/funsor/affine.py index 0751f84b2..1e07a152c 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -61,7 +61,9 @@ def _(fn): @affine_inputs.register(Unary) def _(fn): - if fn.op in (ops.neg, ops.sum) or isinstance(fn.op, ops.ReshapeOp): + if fn.op in (ops.neg, ops.sum) or isinstance( + fn.op, (ops.ReshapeOp, ops.GetsliceOp) + ): return affine_inputs(fn.arg) return frozenset() From 697c6cd019158a8e5b78cbfe996bd6c0d6b713bf Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 27 Sep 2021 09:48:56 -0400 Subject: [PATCH 4/4] Fix eager_getslice_lambda --- funsor/terms.py | 46 +++++++++++++++++++++++++++++----------------- test/test_terms.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 17 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index 707601a0d..021c31878 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -17,6 +17,7 @@ from funsor.domains import ( Array, Bint, + BintType, Domain, Product, ProductDomain, @@ -34,7 +35,7 @@ ) from funsor.interpreter import PatternMissingError, interpret from funsor.ops import AssociativeOp, GetitemOp, Op -from funsor.ops.builtin import normalize_ellipsis, parse_ellipsis +from funsor.ops.builtin import normalize_ellipsis, parse_ellipsis, parse_slice from funsor.syntax import INFIX_OPERATORS, PREFIX_OPERATORS from funsor.typing import GenericTypeMeta, Variadic, deep_type, get_args, get_origin from funsor.util import getargspec, lazy_property, pretty, quote @@ -1478,6 +1479,15 @@ def eager_subs(self, subs): ) +@to_funsor.register(slice) +def slice_to_funsor(s, output=None, dim_to_name=None): + if not isinstance(output, BintType): + raise ValueError("Incompatible slice output: {output}") + start, stop, step = parse_slice(s, output.size) + i = Variable("slice", output) + return Lambda(i, Slice("slice", start, stop, step, output.size)) + + class Align(Funsor): """ Lazy call to ``.align(...)``. @@ -1575,20 +1585,21 @@ def eager_subs(self, subs): index = subs[0][1] # Try to eagerly select an index. - assert index.output == Bint[len(self.parts)] - - if isinstance(index, Number): - # Select a single part. - return self.parts[index.data] - elif isinstance(index, Variable): - # Rename the stacking dimension. - parts = self.parts - return Stack(index.name, parts) - elif isinstance(index, Slice): - parts = self.parts[index.slice] - return Stack(index.name, parts) + if index.output == Bint[len(self.parts)]: + if isinstance(index, Number): + # Select a single part. + return self.parts[index.data] + elif isinstance(index, Variable): + # Rename the stacking dimension. + parts = self.parts + return Stack(index.name, parts) + elif isinstance(index, Slice): + parts = self.parts[index.slice] + return Stack(index.name, parts) + else: + raise NotImplementedError("TODO support advanced indexing in Stack") else: - raise NotImplementedError("TODO support advanced indexing in Stack") + raise NotImplementedError("TODO support slicing in Stack") def eager_reduce(self, op, reduced_vars): parts = self.parts @@ -1765,11 +1776,12 @@ def eager_getslice_lambda(op, x): expr = x.expr if head != slice(None): expr = expr(**{x.var.name: head}) - if x.var.name not in expr.inputs: - return expr if tail: expr = ops.getslice(expr, tail) - return Lambda(x.var, expr) + if x.var.name in expr.inputs: # dim is preserved, e.g. x[1:] + return Lambda(x.var, expr) + else: # dim is eliminated, e.g. x[0] + return expr class Independent(Funsor): diff --git a/test/test_terms.py b/test/test_terms.py index 63d75c67c..5f99e5537 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -629,6 +629,35 @@ def test_stack_lambda(dtype): assert z[1] is x2 +@pytest.mark.parametrize("dtype", ["real", 4, 5]) +def test_stack_lambda_2(dtype): + + x1 = Number(0, dtype) + x2 = Number(1, dtype) + x3 = Number(2, dtype) + x4 = Number(3, dtype) + x = [[x1, x2, x3], [x2, x3, x4]] + + i = Variable("i", Bint[3]) + y1 = Lambda(i, Stack("i", (x1, x2, x3))) + y2 = Lambda(i, Stack("i", (x2, x3, x4))) + + j = Variable("j", Bint[2]) + z = Lambda(j, Stack("j", (y1, y2))) + assert not z.inputs + assert z.output == Array[dtype, (2, 3)] + + assert z[0] is y1 + assert z[1] is y2 + for i, j in itertools.product(range(2), range(3)): + assert z[i, j] is x[i][j] + assert z[:, j][i] is x[i][j] + assert z[i, :][j] is x[i][j] + # TODO support advanced slicing of Stack + # assert z[0:9, j][i] is x[i][j] + # assert z[i, 0:9][j] is x[i][j] + + def test_funsor_tuple(): x = Number(1, 3) y = Number(2.5, "real")