Skip to content

Commit

Permalink
Initial implementation of jnp.delete
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 17, 2021
1 parent 0a84db5 commit 737e479
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ Not every function in NumPy is implemented; contributions are welcome!
cumsum
deg2rad
degrees
delete
diag
diagflat
diag_indices
Expand Down
44 changes: 44 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3606,6 +3606,50 @@ def append(arr, values, axis: Optional[int] = None):
return concatenate([arr, values], axis=axis)


@_wraps(np.delete)
def delete(arr, obj, axis=None):
_check_arraylike("delete", arr)
if axis is None:
arr = ravel(arr)
axis = 0
axis = _canonicalize_axis(axis, arr.ndim)

# Case 1: obj is a static integer.
try:
obj = operator.index(obj)
obj = _canonicalize_axis(obj, arr.shape[axis])
except TypeError:
pass
else:
idx = tuple(slice(None) for i in range(axis))
return concatenate([arr[idx + (slice(0, obj),)], arr[idx + (slice(obj + 1, None),)]], axis=axis)

# Case 2: obj is a static slice.
if isinstance(obj, slice):
# TODO(jakevdp): we should be able to do this dynamically with care.
indices = np.delete(np.arange(arr.shape[axis]), obj)
return take(arr, indices, axis=axis)

# Case 3: obj is an array
# NB: pass both arrays to check for appropriate error message.
_check_arraylike("delete", arr, obj)
obj = core.concrete_or_error(np.asarray, obj, "'obj' array argument of jnp.delete()")

if issubdtype(obj.dtype, integer):
# TODO(jakevdp): in theory this could be done dynamically if obj has no duplicates,
# but this would require the complement of lax.gather.
mask = np.ones(arr.shape[axis], dtype=bool)
mask[obj] = False
elif obj.dtype == bool:
if obj.shape != (arr.shape[axis],):
raise ValueError("np.delete(arr, obj): for boolean indices, obj must be one-dimensional "
"with length matching specified axis.")
mask = ~obj
else:
raise ValueError(f"np.delete(arr, obj): got obj.dtype={obj.dtype}; must be integer or bool.")
return arr[tuple(slice(None) for i in range(axis)) + (mask,)]


@_wraps(np.apply_along_axis)
def apply_along_axis(func1d, axis: int, arr, *args, **kwargs):
num_dims = ndim(arr)
Expand Down
4 changes: 2 additions & 2 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
broadcast_to, can_cast, cbrt, cdouble, ceil, character, choose, clip, column_stack,
complex128, complex64, complex_, complexfloating, compress, concatenate,
conj, conjugate, convolve, copysign, corrcoef, correlate, cos, cosh,
count_nonzero, cov, cross, csingle, cumprod, cumproduct, cumsum, deg2rad,
degrees, diag, diagflat, diag_indices, diag_indices_from, diagonal, diff, digitize, divide, divmod, dot,
count_nonzero, cov, cross, csingle, cumprod, cumproduct, cumsum, deg2rad, degrees,
delete, diag, diagflat, diag_indices, diag_indices_from, diagonal, diff, digitize, divide, divmod, dot,
double, dsplit, dstack, dtype, e, ediff1d, einsum, einsum_path, empty,
empty_like, equal, euler_gamma, exp, exp2, expand_dims, expm1, extract, eye,
fabs, finfo, fix, flatnonzero, flexible, flip, fliplr, flipud, float16, float32,
Expand Down
74 changes: 74 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,6 +1960,80 @@ def args_maker():
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)


@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_idx={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, idx),
"dtype": dtype, "shape": shape, "axis": axis, "idx": idx}
for shape in nonempty_nonscalar_array_shapes
for dtype in all_dtypes
for axis in [None] + list(range(-len(shape), len(shape)))
for idx in (range(-prod(shape), prod(shape))
if axis is None else
range(-shape[axis], shape[axis]))))
def testDeleteInteger(self, shape, dtype, idx, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, idx, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_slc={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, slc),
"dtype": dtype, "shape": shape, "axis": axis, "slc": slc}
for shape in nonempty_nonscalar_array_shapes
for dtype in all_dtypes
for axis in [None] + list(range(-len(shape), len(shape)))
for slc in [slice(None), slice(1, 3), slice(1, 5, 2)]))
def testDeleteSlice(self, shape, dtype, axis, slc):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, slc, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, slc, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_idx={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis,
jtu.format_shape_dtype_string(idx_shape, int)),
"dtype": dtype, "shape": shape, "axis": axis, "idx_shape": idx_shape}
for shape in nonempty_nonscalar_array_shapes
for dtype in all_dtypes
for axis in [None] + list(range(-len(shape), len(shape)))
for idx_shape in all_shapes))
def testDeleteIndexArray(self, shape, dtype, axis, idx_shape):
rng = jtu.rand_default(self.rng())
max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis]
# Previous to numpy 1.19, negative indices were ignored so we don't test this.
low = 0 if numpy_version < (1, 19, 0) else -max_idx
idx = jtu.rand_int(self.rng(), low=low, high=max_idx)(idx_shape, int)
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, idx, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@unittest.skipIf(numpy_version < (1, 19), "boolean mask not supported in numpy < 1.19.0")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),
"dtype": dtype, "shape": shape, "axis": axis}
for shape in nonempty_nonscalar_array_shapes
for dtype in all_dtypes
for axis in [None] + list(range(-len(shape), len(shape)))))
def testDeleteMaskArray(self, shape, dtype, axis):
rng = jtu.rand_default(self.rng())
mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis]
mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool)
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, mask, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_out_dims={}".format(
jtu.format_shape_dtype_string(shape, dtype),
Expand Down

0 comments on commit 737e479

Please sign in to comment.