Skip to content

Commit

Permalink
Add ops.getslice for complex indexing by int,slice,None,Ellipsis (#555)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored Sep 27, 2021
1 parent 3d43e6b commit 86d4a22
Show file tree
Hide file tree
Showing 9 changed files with 379 additions and 25 deletions.
4 changes: 3 additions & 1 deletion funsor/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def _(fn):

@affine_inputs.register(Unary)
def _(fn):
if fn.op in (ops.neg, ops.sum) or isinstance(fn.op, ops.ReshapeOp):
if fn.op in (ops.neg, ops.sum) or isinstance(
fn.op, (ops.ReshapeOp, ops.GetsliceOp)
):
return affine_inputs(fn.arg)
return frozenset()

Expand Down
53 changes: 53 additions & 0 deletions funsor/domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from weakref import WeakValueDictionary

import funsor.ops as ops
from funsor.ops.builtin import parse_ellipsis, parse_slice
from funsor.util import broadcast_shape, get_backend, get_tracing_state, quote

Domain = type
Expand Down Expand Up @@ -331,6 +332,58 @@ def _find_domain_getitem(op, lhs_domain, rhs_domain):
)


def _getslice_dim_size(part, size):
    """Return the number of elements the slice ``part`` selects from ``size``."""
    # slice.indices() resolves None/negative/out-of-range bounds exactly as
    # Python sequence indexing does, for either sign of step, so the length
    # of the resulting range is the length of the sliced dimension.
    return len(range(*part.indices(size)))


@find_domain.register(ops.GetsliceOp)
def _find_domain_getslice(op, domain):
    """
    Compute the output domain of ``ops.getslice`` applied to ``domain``.

    The index is read from ``op.defaults["index"]``.

    :param op: A getslice op instance.
    :param domain: The domain of the value being indexed.
    :returns: For an array domain, an ``Array`` with the same dtype and the
        indexed shape; for a product domain, either the selected component
        (int index) or a ``Product`` of the selected components (slice).
    :raises ValueError: If an index part has an unsupported type, or a slice
        has step zero.
    :raises NotImplementedError: If ``domain`` is neither an array domain nor
        a product domain.
    """
    index = op.defaults["index"]
    if isinstance(domain, ArrayType):
        dtype = domain.dtype
        shape = list(domain.shape)
        left, right = parse_ellipsis(index)

        # Parts left of the Ellipsis are applied front-to-back; ``i`` is the
        # position in ``shape`` that the next part refers to.
        i = 0
        for part in left:
            if part is None:
                # None inserts a new axis of size 1.
                shape.insert(i, 1)
                i += 1
            elif isinstance(part, int):
                # An integer eliminates the dimension it indexes.
                del shape[i]
            elif isinstance(part, slice):
                shape[i] = _getslice_dim_size(part, shape[i])
                i += 1
            else:
                raise ValueError(part)

        # Parts right of the Ellipsis are applied back-to-front, tracking the
        # position as a negative offset from the end of ``shape``.
        i = -1
        for part in reversed(right):
            if part is None:
                shape.insert(len(shape) + i + 1, 1)
                i -= 1
            elif isinstance(part, int):
                del shape[i]
            elif isinstance(part, slice):
                shape[i] = _getslice_dim_size(part, shape[i])
                i -= 1
            else:
                raise ValueError(part)

        return Array[dtype, tuple(shape)]

    if isinstance(domain, ProductDomain):
        # Only a single index part is accepted for product domains.
        if isinstance(index, tuple):
            assert len(index) == 1
            index = index[0]
        if isinstance(index, int):
            return domain.__args__[index]
        elif isinstance(index, slice):
            return Product[domain.__args__[index]]
        else:
            raise ValueError(index)

    raise NotImplementedError("TODO")


@find_domain.register(ops.BinaryOp)
def _find_domain_pointwise_binary_generic(op, lhs, rhs):
if (
Expand Down
101 changes: 101 additions & 0 deletions funsor/ops/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
UNITS,
BinaryOp,
Op,
OpMeta,
TransformOp,
UnaryOp,
declare_op_types,
Expand Down Expand Up @@ -43,6 +44,105 @@ def getitem(lhs, rhs, offset=0):
return lhs[(slice(None),) * offset + (rhs,)]


class GetsliceMeta(OpMeta):
    """
    Metaclass that builds a hashable key from a getslice index.

    Slice objects are replaced by their ``(start, stop, step)`` triples so
    that the resulting key can be hashed.
    """

    def hash_args_kwargs(cls, args, kwargs):
        # The index arrives either positionally or under the "index" keyword.
        if args:
            index = args[0]
        else:
            index = kwargs["index"]
        # Treat a bare index as a 1-tuple so both spellings share one key.
        parts = index if isinstance(index, tuple) else (index,)
        return tuple(
            (part.start, part.stop, part.step) if isinstance(part, slice) else part
            for part in parts
        )


@UnaryOp.make(metaclass=GetsliceMeta)
def getslice(x, index=Ellipsis):
    """
    Index ``x`` by ``index``, i.e. return ``x[index]``.

    :param x: The object to index into.
    :param index: The index to apply; defaults to ``Ellipsis``.
    :returns: ``x[index]``.
    """
    return x[index]


# Types of the individual index parts that getslice is meant to handle:
# None, Ellipsis, integers and slices. Callers can test index parts against
# this tuple with isinstance() before dispatching to getslice.
getslice.supported_types = (type(None), type(Ellipsis), int, slice)


def parse_ellipsis(index):
    """
    Split an index into the parts before and after its Ellipsis.

    A non-tuple index is treated as a 1-tuple. If the index contains no
    Ellipsis, every part is returned in ``left`` and ``right`` is empty. If
    it contains several, ``left`` ends at the first Ellipsis and ``right``
    starts after the last one.

    :param index: A tuple, or other object (None, int, slice, Funsor).
    :returns: a pair of tuples ``left, right``.
    :rtype: tuple
    """
    parts = index if isinstance(index, tuple) else (index,)
    # Identity comparison: only the Ellipsis singleton counts as a separator.
    positions = [k for k, part in enumerate(parts) if part is Ellipsis]
    if not positions:
        return tuple(parts), ()
    first, last = positions[0], positions[-1]
    return tuple(parts[:first]), tuple(parts[last + 1:])


def normalize_ellipsis(index, size):
    """
    Expand Ellipsis in an index to fill the given number of dimensions.

    This should satisfy the equation::

        x[i] == x[normalize_ellipsis(i, len(x.shape))]

    :param index: A tuple, or a single index part.
    :param int size: The number of dimensions of the object being indexed.
    :returns: A tuple index with the Ellipsis (or the implicit trailing
        dimensions) replaced by explicit ``slice(None)`` entries.
    :rtype: tuple
    :raises ValueError: If the index addresses more than ``size`` dimensions.
    """
    left, right = parse_ellipsis(index)
    # ``None`` inserts a new axis instead of indexing an existing one, so it
    # must not be counted against the dimensions available in the operand.
    consumed = sum(part is not None for part in left + right)
    if consumed > size:
        raise ValueError(f"Index is too wide: {index}")
    middle = (slice(None),) * (size - consumed)
    return left + middle + right


def parse_slice(s, size):
    """
    Resolve the ``start``, ``stop`` and ``step`` of a slice against a size.

    Missing values default to ``0``, ``size`` and ``1`` respectively.
    ``start`` and ``stop`` are clamped into ``[0, size]``, with negative
    values counted from the end. ``step`` is returned as given; only its
    type is checked.

    :param slice s: A slice.
    :param int size: The size of the array being indexed into.
    :returns: A tuple of integers ``start, stop, step``.
    :rtype: tuple
    """

    def _clamp(bound):
        # Non-negative bounds are capped at size; negative bounds are
        # offsets from the end, floored at zero.
        return min(size, bound) if bound >= 0 else max(0, size + bound)

    start = 0 if s.start is None else s.start
    assert isinstance(start, int)
    start = _clamp(start)

    stop = size if s.stop is None else s.stop
    assert isinstance(stop, int)
    stop = _clamp(stop)

    step = 1 if s.step is None else s.step
    assert isinstance(step, int)

    return start, stop, step


abs = UnaryOp.make(_builtin_abs)
eq = BinaryOp.make(operator.eq)
ge = BinaryOp.make(operator.ge)
Expand Down Expand Up @@ -194,6 +294,7 @@ def sigmoid_log_abs_det_jacobian(x, y):
"floordiv",
"ge",
"getitem",
"getslice",
"gt",
"invert",
"le",
Expand Down
10 changes: 10 additions & 0 deletions funsor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,16 @@ def eager_getitem_tensor_tensor(op, lhs, rhs):
return Tensor(data, inputs, lhs.dtype)


@eager.register(Unary, ops.GetsliceOp, Tensor)
def eager_getslice_tensor(op, x):
    """
    Eagerly apply a getslice op to a Tensor.

    The op's index is prefixed with one full slice per entry of
    ``x.inputs`` before indexing ``x.data``; the result keeps ``x``'s inputs
    and dtype.
    """
    event_index = op.defaults["index"]
    if not isinstance(event_index, tuple):
        event_index = (event_index,)
    # One slice(None) per input, so the op's index is applied after them.
    batch_index = (slice(None),) * len(x.inputs)
    return Tensor(x.data[batch_index + event_index], x.inputs, x.dtype)


@eager.register(
Finitary, ops.StackOp, typing.Tuple[typing.Union[(Number, Tensor)], ...]
)
Expand Down
93 changes: 69 additions & 24 deletions funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from funsor.domains import (
Array,
Bint,
BintType,
Domain,
Product,
ProductDomain,
Expand All @@ -34,6 +35,7 @@
)
from funsor.interpreter import PatternMissingError, interpret
from funsor.ops import AssociativeOp, GetitemOp, Op
from funsor.ops.builtin import normalize_ellipsis, parse_ellipsis, parse_slice
from funsor.syntax import INFIX_OPERATORS, PREFIX_OPERATORS
from funsor.typing import GenericTypeMeta, Variadic, deep_type, get_args, get_origin
from funsor.util import getargspec, lazy_property, pretty, quote
Expand Down Expand Up @@ -730,23 +732,24 @@ def __ge__(self, other):
return Binary(ops.ge, self, to_funsor(other))

def __getitem__(self, other):
"""
Helper to desugar into either ops.getitem (for advanced indexing
involving Funsors as indices) or ops.getslice (for simple indexing
involving only integers, slices, None, and Ellipsis).
"""
if type(other) is not tuple:
if isinstance(other, ops.getslice.supported_types):
return ops.getslice(self, other)
other = to_funsor(other, Bint[self.output.shape[0]])
return Binary(ops.getitem, self, other)

# Handle complex slicing operations involving no funsors.
if all(isinstance(part, ops.getslice.supported_types) for part in other):
return ops.getslice(self, other)

# Handle Ellipsis slicing.
if any(part is Ellipsis for part in other):
left = []
for part in other:
if part is Ellipsis:
break
left.append(part)
right = []
for part in reversed(other):
if part is Ellipsis:
break
right.append(part)
right.reverse()
left, right = parse_ellipsis(other)
missing = len(self.output.shape) - len(left) - len(right)
assert missing >= 0
middle = [slice(None)] * missing
Expand All @@ -756,6 +759,8 @@ def __getitem__(self, other):
result = self
offset = 0
for part in other:
if part is None:
raise NotImplementedError("TODO")
if isinstance(part, slice):
if part != slice(None):
raise NotImplementedError("TODO support nontrivial slicing")
Expand Down Expand Up @@ -1474,6 +1479,15 @@ def eager_subs(self, subs):
)


@to_funsor.register(slice)
def slice_to_funsor(s, output=None, dim_to_name=None):
    """
    Convert a Python slice to a funsor.

    :param slice s: The slice to convert.
    :param output: The output domain; must be a ``BintType`` instance, whose
        ``size`` is used to resolve the slice bounds.
    :param dim_to_name: Not used by this conversion.
    :returns: A ``Lambda`` over a ``"slice"`` variable wrapping a ``Slice``.
    :raises ValueError: If ``output`` is not a ``BintType`` instance.
    """
    if not isinstance(output, BintType):
        # f-string so the offending domain actually appears in the message.
        raise ValueError(f"Incompatible slice output: {output}")
    start, stop, step = parse_slice(s, output.size)
    i = Variable("slice", output)
    return Lambda(i, Slice("slice", start, stop, step, output.size))


class Align(Funsor):
"""
Lazy call to ``.align(...)``.
Expand Down Expand Up @@ -1571,20 +1585,21 @@ def eager_subs(self, subs):
index = subs[0][1]

# Try to eagerly select an index.
assert index.output == Bint[len(self.parts)]

if isinstance(index, Number):
# Select a single part.
return self.parts[index.data]
elif isinstance(index, Variable):
# Rename the stacking dimension.
parts = self.parts
return Stack(index.name, parts)
elif isinstance(index, Slice):
parts = self.parts[index.slice]
return Stack(index.name, parts)
if index.output == Bint[len(self.parts)]:
if isinstance(index, Number):
# Select a single part.
return self.parts[index.data]
elif isinstance(index, Variable):
# Rename the stacking dimension.
parts = self.parts
return Stack(index.name, parts)
elif isinstance(index, Slice):
parts = self.parts[index.slice]
return Stack(index.name, parts)
else:
raise NotImplementedError("TODO support advanced indexing in Stack")
else:
raise NotImplementedError("TODO support advanced indexing in Stack")
raise NotImplementedError("TODO support slicing in Stack")

def eager_reduce(self, op, reduced_vars):
parts = self.parts
Expand Down Expand Up @@ -1754,6 +1769,21 @@ def eager_getitem_lambda(op, lhs, rhs):
return Lambda(lhs.var, expr)


@eager.register(Unary, ops.GetsliceOp, Lambda)
def eager_getslice_lambda(op, x):
    """
    Eagerly apply a getslice op to a Lambda.

    The first index part is substituted for the Lambda's bound variable
    (unless it is a full slice), and any remaining parts are applied to the
    body via ``ops.getslice``.
    """
    full_index = normalize_ellipsis(op.defaults["index"], len(x.shape))
    first, rest = full_index[0], full_index[1:]
    body = x.expr
    if first != slice(None):
        body = body(**{x.var.name: first})
    if rest:
        body = ops.getslice(body, rest)
    if x.var.name not in body.inputs:
        # dim is eliminated, e.g. x[0]
        return body
    # dim is preserved, e.g. x[1:]
    return Lambda(x.var, body)


class Independent(Funsor):
"""
Creates an independent diagonal distribution.
Expand Down Expand Up @@ -1885,6 +1915,21 @@ def eager_getitem_tuple(op, lhs, rhs):
return op(lhs.args, rhs.data)


@lazy.register(Unary, ops.GetsliceOp, Tuple)
@eager.register(Unary, ops.GetsliceOp, Tuple)
def eager_getslice_tuple(op, x):
    """
    Apply a getslice op to a Tuple funsor.

    :param op: A getslice op instance; its index is read from
        ``op.defaults["index"]`` and may be wrapped in a 1-tuple.
    :param x: The Tuple funsor to index.
    :returns: The selected component for an int index, or a new ``Tuple`` of
        the selected components for a slice index.
    :raises ValueError: If the index is neither an int nor a slice.
    """
    index = op.defaults["index"]
    if isinstance(index, tuple):
        assert len(index) == 1
        index = index[0]
    # Index with the unwrapped value rather than re-applying ``op``: the op
    # still carries the original (possibly 1-tuple) index, which builtin
    # sequences reject. NOTE(review): presumes x.args is a plain sequence
    # such as a tuple -- confirm against Tuple's definition.
    if isinstance(index, int):
        return x.args[index]
    elif isinstance(index, slice):
        return Tuple(x.args[index])
    else:
        raise ValueError(index)


def _symbolic(inputs, output, fn):
args, vargs, kwargs, defaults = getargspec(fn)
assert not vargs
Expand Down
17 changes: 17 additions & 0 deletions funsor/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,20 @@ def iter_subsets(iterable, *, min_size=None, max_size=None):
max_size = len(iterable)
for size in range(min_size, max_size + 1):
yield from itertools.combinations(iterable, size)


class DesugarGetitem:
    """
    Helper to desugar ``.__getitem__()`` syntax.

    Indexing an instance returns the index object itself, exactly as Python
    passes it to ``__getitem__``.

    Example::

        >>> desugar_getitem[1:3, ..., None]
        (slice(1, 3, None), Ellipsis, None)
    """

    def __getitem__(self, index):
        # Return the index unchanged so callers can inspect what the
        # subscript syntax produced.
        return index


# Shared instance, so that ``desugar_getitem[...]`` can be used directly.
desugar_getitem = DesugarGetitem()
Loading

0 comments on commit 86d4a22

Please sign in to comment.