Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation of jnp.delete #6085

Merged
merged 1 commit into from
Mar 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ Not every function in NumPy is implemented; contributions are welcome!
cumsum
deg2rad
degrees
delete
diag
diagflat
diag_indices
Expand Down
44 changes: 44 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3606,6 +3606,50 @@ def append(arr, values, axis: Optional[int] = None):
return concatenate([arr, values], axis=axis)


@_wraps(np.delete)
def delete(arr, obj, axis=None):
_check_arraylike("delete", arr)
if axis is None:
arr = ravel(arr)
axis = 0
axis = _canonicalize_axis(axis, arr.ndim)

# Case 1: obj is a static integer.
try:
obj = operator.index(obj)
obj = _canonicalize_axis(obj, arr.shape[axis])
except TypeError:
pass
else:
idx = tuple(slice(None) for i in range(axis))
return concatenate([arr[idx + (slice(0, obj),)], arr[idx + (slice(obj + 1, None),)]], axis=axis)

# Case 2: obj is a static slice.
if isinstance(obj, slice):
# TODO(jakevdp): we should be able to do this dynamically with care.
indices = np.delete(np.arange(arr.shape[axis]), obj)
return take(arr, indices, axis=axis)

# Case 3: obj is an array
# NB: pass both arrays to check for appropriate error message.
_check_arraylike("delete", arr, obj)
obj = core.concrete_or_error(np.asarray, obj, "'obj' array argument of jnp.delete()")

if issubdtype(obj.dtype, integer):
# TODO(jakevdp): in theory this could be done dynamically if obj has no duplicates,
# but this would require the complement of lax.gather.
mask = np.ones(arr.shape[axis], dtype=bool)
mask[obj] = False
elif obj.dtype == bool:
if obj.shape != (arr.shape[axis],):
raise ValueError("np.delete(arr, obj): for boolean indices, obj must be one-dimensional "
"with length matching specified axis.")
mask = ~obj
else:
raise ValueError(f"np.delete(arr, obj): got obj.dtype={obj.dtype}; must be integer or bool.")
return arr[tuple(slice(None) for i in range(axis)) + (mask,)]


@_wraps(np.apply_along_axis)
def apply_along_axis(func1d, axis: int, arr, *args, **kwargs):
num_dims = ndim(arr)
Expand Down
4 changes: 2 additions & 2 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
broadcast_to, can_cast, cbrt, cdouble, ceil, character, choose, clip, column_stack,
complex128, complex64, complex_, complexfloating, compress, concatenate,
conj, conjugate, convolve, copysign, corrcoef, correlate, cos, cosh,
count_nonzero, cov, cross, csingle, cumprod, cumproduct, cumsum, deg2rad,
degrees, diag, diagflat, diag_indices, diag_indices_from, diagonal, diff, digitize, divide, divmod, dot,
count_nonzero, cov, cross, csingle, cumprod, cumproduct, cumsum, deg2rad, degrees,
delete, diag, diagflat, diag_indices, diag_indices_from, diagonal, diff, digitize, divide, divmod, dot,
double, dsplit, dstack, dtype, e, ediff1d, einsum, einsum_path, empty,
empty_like, equal, euler_gamma, exp, exp2, expand_dims, expm1, extract, eye,
fabs, finfo, fix, flatnonzero, flexible, flip, fliplr, flipud, float16, float32,
Expand Down
74 changes: 74 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,6 +1960,80 @@ def args_maker():
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)


@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_idx={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, idx),
"dtype": dtype, "shape": shape, "axis": axis, "idx": idx}
for shape in nonempty_nonscalar_array_shapes
for dtype in all_dtypes
for axis in [None] + list(range(-len(shape), len(shape)))
for idx in (range(-prod(shape), prod(shape))
if axis is None else
range(-shape[axis], shape[axis]))))
def testDeleteInteger(self, shape, dtype, idx, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, idx, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_slc={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, slc),
"dtype": dtype, "shape": shape, "axis": axis, "slc": slc}
for shape in nonempty_nonscalar_array_shapes
for dtype in all_dtypes
for axis in [None] + list(range(-len(shape), len(shape)))
for slc in [slice(None), slice(1, 3), slice(1, 5, 2)]))
def testDeleteSlice(self, shape, dtype, axis, slc):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, slc, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, slc, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_idx={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis,
jtu.format_shape_dtype_string(idx_shape, int)),
"dtype": dtype, "shape": shape, "axis": axis, "idx_shape": idx_shape}
for shape in nonempty_nonscalar_array_shapes
for dtype in all_dtypes
for axis in [None] + list(range(-len(shape), len(shape)))
for idx_shape in all_shapes))
def testDeleteIndexArray(self, shape, dtype, axis, idx_shape):
rng = jtu.rand_default(self.rng())
max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis]
# Previous to numpy 1.19, negative indices were ignored so we don't test this.
low = 0 if numpy_version < (1, 19, 0) else -max_idx
idx = jtu.rand_int(self.rng(), low=low, high=max_idx)(idx_shape, int)
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, idx, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@unittest.skipIf(numpy_version < (1, 19), "boolean mask not supported in numpy < 1.19.0")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),
"dtype": dtype, "shape": shape, "axis": axis}
for shape in nonempty_nonscalar_array_shapes
for dtype in all_dtypes
for axis in [None] + list(range(-len(shape), len(shape)))))
def testDeleteMaskArray(self, shape, dtype, axis):
rng = jtu.rand_default(self.rng())
mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis]
mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool)
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, mask, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_out_dims={}".format(
jtu.format_shape_dtype_string(shape, dtype),
Expand Down