Skip to content

Commit

Permalink
jax.numpy: require boolean dtype for where argument
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 5, 2024
1 parent 1a3c9c4 commit 29a8cce
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
1 change: 1 addition & 0 deletions jax/_src/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,5 +130,6 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
register('jax-numpy-linalg-matrix_rank-tol')
register('jax-numpy-linalg-pinv-rcond')
register('jax-numpy-quantile-interpolation')
register('jax-numpy-reduction-non-boolean-where')
register('jax-numpy-trimzeros-not-1d-array')
register('pallas-gpu-triton')
33 changes: 30 additions & 3 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,20 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike:
return dtypes.int_
return dtype

def check_where(name: str, where: ArrayLike | None) -> Array | None:
if where is None:
return where
check_arraylike(name, where)
where_arr = lax_internal.asarray(where)
if where_arr.dtype != bool:
# Deprecation added 2024-12-05
deprecations.warn(
'jax-numpy-reduction-non-boolean-where',
f"jnp.{name}: where must be None or a boolean array; got dtype={where_arr.dtype}.",
stacklevel=2)
return where_arr.astype(bool)
return where_arr


ReductionOp = Callable[[Any, Any], Any]

Expand All @@ -101,6 +115,7 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike,
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
check_arraylike(name, a)
where_ = check_where(name, where_)
dtypes.check_user_dtype_supported(dtype, name)
axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")

Expand Down Expand Up @@ -730,6 +745,8 @@ def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.")
dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce")
check_arraylike("logsumexp", a)
where = check_where("logsumexp", where)
a_arr, = promote_dtypes_inexact(a)
pos_dims, dims = _reduction_dims(a_arr, axis)
amax = max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf)
Expand All @@ -748,6 +765,8 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.")
dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce")
check_arraylike("logsumexp2", a)
where = check_where("logsumexp2", where)
ln2 = float(np.log(2))
if initial is not None:
initial *= ln2
Expand Down Expand Up @@ -850,6 +869,7 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
upcast_f16_for_computation: bool = True,
where: ArrayLike | None = None) -> Array:
check_arraylike("mean", a)
where = check_where("mean", where)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")

Expand Down Expand Up @@ -1087,6 +1107,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("var", a)
where = check_where("var", where)
dtypes.check_user_dtype_supported(dtype, "var")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.var is not supported.")
Expand Down Expand Up @@ -1224,6 +1245,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("std", a)
where = check_where("std", where)
dtypes.check_user_dtype_supported(dtype, "std")
if dtype is not None and not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
Expand Down Expand Up @@ -1330,13 +1352,15 @@ def count_nonzero(a: ArrayLike, axis: Axis = None,

def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
init_val: ArrayLike, nan_if_all_nan: bool,
axis: Axis = None, keepdims: bool = False, **kwargs) -> Array:
axis: Axis = None, keepdims: bool = False, where: ArrayLike | None = None,
**kwargs) -> Array:
check_arraylike(name, a)
where = check_where(name, where)
if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)
return jnp_reduction(a, axis=axis, keepdims=keepdims, where=where, **kwargs)

out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a),
axis=axis, keepdims=keepdims, **kwargs)
axis=axis, keepdims=keepdims, where=where, **kwargs)
if nan_if_all_nan:
return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
_lax_const(a, np.nan), out)
Expand Down Expand Up @@ -1755,6 +1779,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
Array([[nan, nan, nan, nan]], dtype=float32)
"""
check_arraylike("nanmean", a)
where = check_where("nanmean", where)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer):
Expand Down Expand Up @@ -1848,6 +1873,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
[4. ]], dtype=float32)
"""
check_arraylike("nanvar", a)
where = check_where("nanvar", where)
dtypes.check_user_dtype_supported(dtype, "nanvar")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")
Expand Down Expand Up @@ -1943,6 +1969,7 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
Array([[0.5, 0.5, 0. , 0. ]], dtype=float32)
"""
check_arraylike("nanstd", a)
where = check_where("nanstd", where)
dtypes.check_user_dtype_supported(dtype, "nanstd")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
Expand Down
28 changes: 28 additions & 0 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,34 @@ def np_fun(x):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(rec=JAX_REDUCER_INITIAL_RECORDS)
def testReducerWhereNonBooleanErrorInitial(self, rec):
dtype = rec.dtypes[0]
x = jnp.zeros((10,), dtype)
where = jnp.ones(10, dtype=int)
func = getattr(jnp, rec.name)
def assert_warns_or_errors(msg):
if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"):
func(x, where=where, initial=jnp.array(0, dtype=dtype))

@jtu.sample_product(rec=JAX_REDUCER_WHERE_NO_INITIAL_RECORDS)
def testReducerWhereNonBooleanErrorNoInitial(self, rec):
dtype = rec.dtypes[0]
x = jnp.zeros((10,), dtype)
where = jnp.ones(10, dtype=int)
func = getattr(jnp, rec.name)
def assert_warns_or_errors(msg):
if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"):
func(x, where=where)

@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact,
Expand Down

0 comments on commit 29a8cce

Please sign in to comment.