Skip to content

Commit

Permalink
Add new functions jax.ops.index_add and jax.ops.index_update for …
Browse files Browse the repository at this point in the history
…NumPy-style indexed updates.

Create a new library `jax.ops` for user-facing ops that don't exist in NumPy or SciPy.

Progress on issue jax-ml#101. Fixes jax-ml#122.
  • Loading branch information
hawkinsp committed Mar 4, 2019
1 parent ff89a90 commit 25a93c2
Show file tree
Hide file tree
Showing 4 changed files with 515 additions and 160 deletions.
1 change: 1 addition & 0 deletions docs/jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Subpackages
jax.scipy
jax.experimental
jax.lax
jax.ops
jax.random

Module contents
Expand Down
17 changes: 17 additions & 0 deletions jax/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

from .scatter import index, index_add, index_update
246 changes: 246 additions & 0 deletions jax/ops/scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Helpers for indexed updates.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as onp

from ..abstract_arrays import ShapedArray, ConcreteArray
from .. import core
from .. import lax
from ..numpy import lax_numpy as np


def _scatter_update(x, idx, y, scatter_op):
"""Helper for indexed updates.
Computes the value of x that would result from computing::
x[idx] op= y
except in a pure functional way, with no in-place updating.
Support NumPy-style basic indexing only, i.e., `idx` must be
`None`, an integer, a `slice` object, or ellipses, or a tuple of the above.
TODO(phawkins): support advanced indexing.
"""

x = np.asarray(x)
y = np.asarray(y)
x_shape = np.shape(x)
y_shape = np.shape(y)
y = lax.convert_element_type(y, lax._dtype(x))

if not isinstance(idx, tuple):
idx = (idx,)

# Test for unsupported advanced indexing and report an error.
if any(onp.ndim(elt) != 0 for elt in idx):
raise NotImplementedError("Unimplemented case for indexed update. Advanced "
"indexing is not yet implemented.")

# Remove ellipses and add trailing slice(None)s.
idx = np._canonicalize_tuple_index(x, idx)

_int = lambda aval: not aval.shape and onp.issubdtype(aval.dtype, onp.integer)

x_axis = 0
y_axis = 0 # Current axis in y, before collapsing. See below.
collapsed_y_axis = 0 # Current axis in y, after collapsing.

# Scatter dimension numbers.
update_window_dims = []
inserted_window_dims = []
scatter_dims_to_operand_dims = []

scatter_indices = np.zeros((0,), dtype=np.int32)

# We perform three transformations to y before the scatter op, in order:
# First, y is broadcast to slice_shape. In general `y` only need broadcast to
# the right shape.
slice_shape = []
# Next, y is reshaped to collapsed_slice_shape. This is to handle `None`
# indices, which the scatter cannot remove itself.
collapsed_slice_shape = []
# Finally, we reverse reversed_y_dims to handle slices with negative strides.
reversed_y_dims = []

for i in idx:
try:
abstract_i = core.get_aval(i)
except TypeError:
abstract_i = None
if (isinstance(abstract_i, ConcreteArray) or
isinstance(abstract_i, ShapedArray)) and _int(abstract_i):
i = np.mod(i, np._constant_like(i, x.shape[x_axis]))
i = lax.convert_element_type(i, np.int32)
i = np.broadcast_to(i, tuple(scatter_indices.shape[:-1]) + (1,))
scatter_indices = np.concatenate((scatter_indices, i), -1)
inserted_window_dims.append(x_axis)
scatter_dims_to_operand_dims.append(x_axis)
x_axis += 1
elif i is None:
slice_shape.append(1)
y_axis += 1
elif np._is_slice_none(i):
slice_shape.append(x_shape[x_axis])
collapsed_slice_shape.append(x_shape[x_axis])
update_window_dims.append(collapsed_y_axis)
collapsed_y_axis += 1
y_axis += 1
x_axis += 1
elif isinstance(i, slice):
start, limit, stride, needs_rev = np._static_idx(i, x.shape[x_axis])
if needs_rev:
reversed_y_dims.append(collapsed_y_axis)
if stride == 1:
i = lax.convert_element_type(start, np.int32)
i = np.broadcast_to(i, tuple(scatter_indices.shape[:-1]) + (1,))
scatter_indices = np.concatenate((scatter_indices, i), -1)
slice_shape.append(limit - start)
collapsed_slice_shape.append(limit - start)
update_window_dims.append(collapsed_y_axis)
scatter_dims_to_operand_dims.append(x_axis)
else:
i = np.arange(start, limit, stride, dtype=np.int32)
size = i.shape[0]
slice_shape.append(size)
collapsed_slice_shape.append(size)
scatter_indices_shape = tuple(scatter_indices.shape[:-1]) + (size,)
i = lax.broadcast_in_dim(
i, shape=scatter_indices_shape + (1,),
broadcast_dimensions=(len(scatter_indices_shape) - 1,))
scatter_indices = lax.broadcast_in_dim(
scatter_indices,
shape=scatter_indices_shape + (len(scatter_dims_to_operand_dims),),
broadcast_dimensions=(
tuple(range(len(scatter_indices_shape) - 1)) +
(len(scatter_indices_shape),)))
scatter_indices = np.concatenate(
(scatter_indices, i), len(scatter_indices_shape))
scatter_dims_to_operand_dims.append(x_axis)
inserted_window_dims.append(x_axis)

collapsed_y_axis += 1
y_axis += 1
x_axis += 1
else:
raise IndexError("Unknown index type ", i)

y = np.broadcast_to(y, tuple(slice_shape))
y = lax.reshape(y, collapsed_slice_shape)
if reversed_y_dims:
y = lax.rev(y, reversed_y_dims)

dnums = lax.ScatterDimensionNumbers(
update_window_dims = tuple(update_window_dims),
inserted_window_dims = tuple(inserted_window_dims),
scatter_dims_to_operand_dims = tuple(scatter_dims_to_operand_dims)
)
return scatter_op(x, scatter_indices, y, dnums)


class _Indexable(object):
"""Helper object for building indexes for indexed update functions.
This is a singleton object that overrides the :code:`__getitem__` method
to return the index it is passed.
>>> jax.ops.index[1:2, 3, None, ..., ::2]
(slice(1, 2, None), 3, None, Ellipsis, slice(None, None, 2))
"""
__slots__ = ()

def __getitem__(self, index):
return index

#: Index object singleton
index = _Indexable()


def index_add(x, idx, y):
"""Pure equivalent of :code:`x[idx] += y`.
Returns the the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
x[idx] += y
Note the `index_add` operator is pure; `x` itself is
not modified, instead the new value that `x` would have taken is returned.
Unlike the NumPy code :code:`x[idx] += y`, if multiple indices refer to the
same location the updates will be summed. (NumPy would only apply the last
update, rather than summing the updates.) The order in which conflicting
updates are applied is implementation-defined and may be nondeterministic
(e.g., due to concurrency on some hardware platforms).
Args:
x: an array.
idx: a Numpy-style basic index, consisting of `None`, integers, `slice`
objects, ellipses, or a tuple of the above. A convenient syntactic sugar
for forming indices is via the :data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.
Returns:
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_add(x, jax.ops.index[2:4, 3:], 6.)
array([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 7., 7., 7.],
[1., 1., 1., 7., 7., 7.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
return _scatter_update(x, idx, y, lax.scatter_add)

def index_update(x, idx, y):
"""Pure equivalent of :code:`x[idx] = y`.
Returns the the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
x[idx] += y
Note the `index_update` operator is pure; `x` itself is
not modified, instead the new value that `x` would have taken is returned.
Unlike NumPy's :code:`x[idx] = y`, if multiple indices refer to the same
location it is undefined which update is chosen; JAX may choose the order of
updates arbitrarily and nondeterministically (e.g., due to concurrent
updates on some hardware platforms).
Args:
x: an array.
idx: a Numpy-style basic index, consisting of `None`, integers, `slice`
objects, ellipses, or a tuple of the above. A convenient syntactic sugar
for forming indices is via the :data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.
Returns:
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_update(x, jax.ops.index[::2, 3:], 6.)
array([[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.]], dtype=float32)
"""
return _scatter_update(x, idx, y, lax.scatter)
Loading

0 comments on commit 25a93c2

Please sign in to comment.