Skip to content

Commit

Permalink
ENH: three argument version of where (pydata#1496)
Browse files Browse the repository at this point in the history
* ENH: three argument version of where

Fixes GH576

* Docstring fixup

* Use join=exact for three argument where

* Add where function

* Don't require pytest 3 for raises_regex
  • Loading branch information
shoyer authored and fujiisoup committed Aug 22, 2017
1 parent e3e6db5 commit 2ff5b20
Show file tree
Hide file tree
Showing 17 changed files with 325 additions and 139 deletions.
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Top-level functions
broadcast
concat
merge
where
set_options
full_like
zeros_like
Expand Down
8 changes: 7 additions & 1 deletion doc/computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ numpy) over all array values:
.. ipython:: python
arr = xr.DataArray(np.random.randn(2, 3),
arr = xr.DataArray(np.random.RandomState(0).randn(2, 3),
[('x', ['a', 'b']), ('y', [10, 20, 30])])
arr - 3
abs(arr)
Expand All @@ -39,6 +39,12 @@ __ http://docs.scipy.org/doc/numpy/reference/ufuncs.html
np.sin(arr)
Use :py:func:`~xarray.where` to conditionally switch between values:

.. ipython:: python
xr.where(arr > 0, 'positive', 'negative')
Data arrays also implement many :py:class:`numpy.ndarray` methods:

.. ipython:: python
Expand Down
25 changes: 24 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@ Enhancements
- Speed-up (x 100) of :py:func:`~xarray.conventions.decode_cf_datetime`.
By `Christian Chwala <https://github.com/cchwala>`_.

- New function :py:func:`~xarray.where` for conditionally switching between
values in xarray objects, like :py:func:`numpy.where`:

.. ipython::
:verbatim:

In [1]: import xarray as xr

In [2]: arr = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=('x', 'y'))

In [3]: xr.where(arr % 2, 'even', 'odd')
Out[3]:
<xarray.DataArray (x: 2, y: 3)>
array([['even', 'odd', 'even'],
['odd', 'even', 'odd']],
dtype='<U4')
Dimensions without coordinates: x, y

Equivalently, the :py:meth:`~xarray.Dataset.where` method also now supports
the ``other`` argument, for filling with a value other than ``NaN``
(:issue:`576`).
By `Stephan Hoyer <https://github.com/shoyer>`_.

Bug fixes
~~~~~~~~~

Expand All @@ -49,7 +72,7 @@ Bug fixes

- Fix :py:func:`xarray.testing.assert_allclose` to actually use ``atol`` and
``rtol`` arguments when called on ``DataArray`` objects.
By `Stephan Hoyer <http://github.com/shoyer>`_.
By `Stephan Hoyer <https://github.com/shoyer>`_.

.. _whats-new.0.9.6:

Expand Down
1 change: 1 addition & 0 deletions xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .core.alignment import align, broadcast, broadcast_arrays
from .core.common import full_like, zeros_like, ones_like
from .core.combine import concat, auto_combine
from .core.computation import where
from .core.extensions import (register_dataarray_accessor,
register_dataset_accessor)
from .core.variable import as_variable, Variable, IndexVariable, Coordinate
Expand Down
6 changes: 2 additions & 4 deletions xarray/core/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
from __future__ import division
from __future__ import print_function

from .common import is_datetime_like
from .dtypes import is_datetime_like
from .pycompat import dask_array_type

from functools import partial

import numpy as np
import pandas as pd

Expand Down Expand Up @@ -147,4 +145,4 @@ def f(self, dtype=dtype):

time = _tslib_field_accessor(
"time", "Timestamps corresponding to datetimes", object
)
)
7 changes: 4 additions & 3 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

import numpy as np

from . import duck_array_ops, utils
from .common import _maybe_promote
from . import duck_array_ops
from . import dtypes
from . import utils
from .indexing import get_indexer
from .pycompat import iteritems, OrderedDict, suppress
from .utils import is_full_slice, is_dict_like
Expand Down Expand Up @@ -368,7 +369,7 @@ def var_indexers(var, indexers):
if any_not_full_slices(assign_to):
# there are missing values to in-fill
data = var[assign_from].data
dtype, fill_value = _maybe_promote(var.dtype)
dtype, fill_value = dtypes.maybe_promote(var.dtype)

if isinstance(data, np.ndarray):
shape = tuple(new_sizes.get(dim, size)
Expand Down
134 changes: 49 additions & 85 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import pandas as pd

from .pycompat import basestring, suppress, dask_array_type, OrderedDict
from . import dtypes
from . import formatting
from . import ops
from .utils import SortedKeysDict, not_implemented, Frozen


Expand Down Expand Up @@ -557,64 +559,74 @@ def resample(self, freq, dim, how='mean', skipna=None, closed=None,
result = result.rename({RESAMPLE_DIM: dim.name})
return result

def where(self, cond, other=None, drop=False):
"""Return an object of the same shape with all entries where cond is
True and all other entries masked.
def where(self, cond, other=dtypes.NA, drop=False):
"""Filter elements from this object according to a condition.
This operation follows the normal broadcasting and alignment rules that
xarray uses for binary arithmetic.
Parameters
----------
cond : boolean DataArray or Dataset
other : unimplemented, optional
Unimplemented placeholder for compatibility with future
numpy / pandas versions
cond : DataArray or Dataset with boolean dtype
Locations at which to preserve this object's values.
other : scalar, DataArray or Dataset, optional
Value to use for locations in this object where ``cond`` is False.
By default, these locations filled with NA.
drop : boolean, optional
Coordinate labels that only correspond to NA values should be
dropped
If True, coordinate labels that only correspond to False values of
the condition are dropped from the result. Mutually exclusive with
``other``.
Returns
-------
same type as caller or if drop=True same type as caller with dimensions
reduced for dim element where mask is True
Same type as caller.
Examples
--------
>>> import numpy as np
>>> a = xr.DataArray(np.arange(25).reshape(5, 5), dims=('x', 'y'))
>>> a.where((a > 6) & (a < 18))
>>> a.where(a.x + a.y < 4)
<xarray.DataArray (x: 5, y: 5)>
array([[ nan, nan, nan, nan, nan],
[ nan, nan, 7., 8., 9.],
[ 10., 11., 12., 13., 14.],
[ 15., 16., 17., nan, nan],
array([[ 0., 1., 2., 3., nan],
[ 5., 6., 7., nan, nan],
[ 10., 11., nan, nan, nan],
[ 15., nan, nan, nan, nan],
[ nan, nan, nan, nan, nan]])
Coordinates:
* y (y) int64 0 1 2 3 4
* x (x) int64 0 1 2 3 4
>>> a.where((a > 6) & (a < 18), drop=True)
Dimensions without coordinates: x, y
>>> a.where(a.x + a.y < 5, -1)
<xarray.DataArray (x: 5, y: 5)>
array([[ nan, nan, 7., 8., 9.],
[ 10., 11., 12., 13., 14.],
[ 15., 16., 17., nan, nan],
Coordinates:
* x (x) int64 1 2 3
* y (y) int64 0 1 2 3 4
array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, -1],
[10, 11, 12, -1, -1],
[15, 16, -1, -1, -1],
[20, -1, -1, -1, -1]])
Dimensions without coordinates: x, y
>>> a.where(a.x + a.y < 4, drop=True)
<xarray.DataArray (x: 4, y: 4)>
array([[ 0., 1., 2., 3.],
[ 5., 6., 7., nan],
[ 10., 11., nan, nan],
[ 15., nan, nan, nan]])
Dimensions without coordinates: x, y
See also
--------
numpy.where : corresponding numpy function
where : equivalent function
"""
if other is not None:
raise NotImplementedError("The optional argument 'other' has not "
"yet been implemented")
from .alignment import align
from .dataarray import DataArray
from .dataset import Dataset

if drop:
from .dataarray import DataArray
from .dataset import Dataset
from .alignment import align
if other is not dtypes.NA:
raise ValueError('cannot set `other` if drop=True')

if not isinstance(cond, (Dataset, DataArray)):
raise TypeError("Cond argument is %r but must be a %r or %r" %
raise TypeError("cond argument is %r but must be a %r or %r" %
(cond, Dataset, DataArray))

# align so we can use integer indexing
self, cond = align(self, cond)

Expand All @@ -627,16 +639,11 @@ def where(self, cond, other=None, drop=False):
# clip the data corresponding to coordinate dims that are not used
nonzeros = zip(clipcond.dims, np.nonzero(clipcond.values))
indexers = {k: np.unique(v) for k, v in nonzeros}
outobj = self.isel(**indexers)
outcond = cond.isel(**indexers)
else:
outobj = self
outcond = cond

# preserve attributes
out = outobj._where(outcond)
out._copy_attrs_from(self)
return out
self = self.isel(**indexers)
cond = cond.isel(**indexers)

return ops.where_method(self, cond, other)

def close(self):
"""Close any files linked to this object
Expand All @@ -658,42 +665,6 @@ def __exit__(self, exc_type, exc_value, traceback):
__or__ = __div__ = __eq__ = __ne__ = not_implemented


def _maybe_promote(dtype):
"""Simpler equivalent of pandas.core.common._maybe_promote"""
# N.B. these casting rules should match pandas
if np.issubdtype(dtype, float):
fill_value = np.nan
elif np.issubdtype(dtype, int):
# convert to floating point so NaN is valid
dtype = float
fill_value = np.nan
elif np.issubdtype(dtype, complex):
fill_value = np.nan + np.nan * 1j
elif np.issubdtype(dtype, np.datetime64):
fill_value = np.datetime64('NaT')
elif np.issubdtype(dtype, np.timedelta64):
fill_value = np.timedelta64('NaT')
else:
dtype = object
fill_value = np.nan
return np.dtype(dtype), fill_value


def _possibly_convert_objects(values):
"""Convert arrays of datetime.datetime and datetime.timedelta objects into
datetime64 and timedelta64, according to the pandas convention.
"""
return np.asarray(pd.Series(values.ravel())).reshape(values.shape)


def _get_fill_value(dtype):
"""Return a fill value that appropriately promotes types when used with
np.concatenate
"""
_, fill_value = _maybe_promote(dtype)
return fill_value


def full_like(other, fill_value, dtype=None):
"""Return a new object with the same shape and type as a given object.
Expand Down Expand Up @@ -761,10 +732,3 @@ def ones_like(other, dtype=None):
"""Shorthand for full_like(other, 1, dtype)
"""
return full_like(other, 1, dtype)


def is_datetime_like(dtype):
"""Check if a dtype is a subclass of the numpy datetime types
"""
return (np.issubdtype(dtype, np.datetime64) or
np.issubdtype(dtype, np.timedelta64))
Loading

0 comments on commit 2ff5b20

Please sign in to comment.