Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: three argument version of where #1496

Merged
merged 5 commits into from
Aug 8, 2017
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ Enhancements

- Speed-up (x 100) of :py:func:`~xarray.conventions.decode_cf_datetime`.
By `Christian Chwala <https://github.com/cchwala>`_.
- :py:meth:`~xarray.Dataset.where` now supports the ``other`` argument, for
filling with a value other than ``NaN`` (:issue:`576`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also comment about xr.where function.


Bug fixes
~~~~~~~~~
Expand All @@ -49,7 +52,7 @@ Bug fixes

- Fix :py:func:`xarray.testing.assert_allclose` to actually use ``atol`` and
``rtol`` arguments when called on ``DataArray`` objects.
By `Stephan Hoyer <http://github.com/shoyer>`_.
By `Stephan Hoyer <https://github.com/shoyer>`_.

.. _whats-new.0.9.6:

Expand Down
6 changes: 2 additions & 4 deletions xarray/core/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
from __future__ import division
from __future__ import print_function

from .common import is_datetime_like
from .dtypes import is_datetime_like
from .pycompat import dask_array_type

from functools import partial

import numpy as np
import pandas as pd

Expand Down Expand Up @@ -147,4 +145,4 @@ def f(self, dtype=dtype):

time = _tslib_field_accessor(
"time", "Timestamps corresponding to datetimes", object
)
)
7 changes: 4 additions & 3 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

import numpy as np

from . import duck_array_ops, utils
from .common import _maybe_promote
from . import duck_array_ops
from . import dtypes
from . import utils
from .indexing import get_indexer
from .pycompat import iteritems, OrderedDict, suppress
from .utils import is_full_slice, is_dict_like
Expand Down Expand Up @@ -368,7 +369,7 @@ def var_indexers(var, indexers):
if any_not_full_slices(assign_to):
# there are missing values to in-fill
data = var[assign_from].data
dtype, fill_value = _maybe_promote(var.dtype)
dtype, fill_value = dtypes.maybe_promote(var.dtype)

if isinstance(data, np.ndarray):
shape = tuple(new_sizes.get(dim, size)
Expand Down
129 changes: 44 additions & 85 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import pandas as pd

from .pycompat import basestring, suppress, dask_array_type, OrderedDict
from . import dtypes
from . import formatting
from . import ops
from .utils import SortedKeysDict, not_implemented, Frozen


Expand Down Expand Up @@ -557,64 +559,69 @@ def resample(self, freq, dim, how='mean', skipna=None, closed=None,
result = result.rename({RESAMPLE_DIM: dim.name})
return result

def where(self, cond, other=None, drop=False):
"""Return an object of the same shape with all entries where cond is
True and all other entries masked.
def where(self, cond, other=dtypes.NA, drop=False):
"""Return elements from `self` or `other` depending on `cond`.

This operation follows the normal broadcasting and alignment rules that
xarray uses for binary arithmetic.

Parameters
----------
cond : boolean DataArray or Dataset
other : unimplemented, optional
Unimplemented placeholder for compatibility with future
numpy / pandas versions
cond : DataArray or Dataset with boolean dtype
Locations at which to preserve this object's values.
other : scalar, DataArray or Dataset, optional
Value to use for locations in this object where ``cond`` is False.
By default, these locations filled with NA.
drop : boolean, optional
Coordinate labels that only correspond to NA values should be
dropped
If True, coordinate labels that only correspond to False values of
the condition are dropped from the result. Mutually exclusive with
``other``.

Returns
-------
same type as caller or if drop=True same type as caller with dimensions
reduced for dim element where mask is True
Same type as caller.

Examples
--------

>>> import numpy as np
>>> a = xr.DataArray(np.arange(25).reshape(5, 5), dims=('x', 'y'))
>>> a.where((a > 6) & (a < 18))
>>> a.where(a.x + a.y < 4)
<xarray.DataArray (x: 5, y: 5)>
array([[ nan, nan, nan, nan, nan],
[ nan, nan, 7., 8., 9.],
[ 10., 11., 12., 13., 14.],
[ 15., 16., 17., nan, nan],
array([[ 0., 1., 2., 3., nan],
[ 5., 6., 7., nan, nan],
[ 10., 11., nan, nan, nan],
[ 15., nan, nan, nan, nan],
[ nan, nan, nan, nan, nan]])
Coordinates:
* y (y) int64 0 1 2 3 4
* x (x) int64 0 1 2 3 4
>>> a.where((a > 6) & (a < 18), drop=True)
Dimensions without coordinates: x, y
>>> a.where(a.x + a.y < 5, -1)
<xarray.DataArray (x: 5, y: 5)>
array([[ nan, nan, 7., 8., 9.],
[ 10., 11., 12., 13., 14.],
[ 15., 16., 17., nan, nan],
Coordinates:
* x (x) int64 1 2 3
* y (y) int64 0 1 2 3 4
array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, -1],
[10, 11, 12, -1, -1],
[15, 16, -1, -1, -1],
[20, -1, -1, -1, -1]])
Dimensions without coordinates: x, y
>>> a.where(a.x + a.y < 4, drop=True)
<xarray.DataArray (x: 4, y: 4)>
array([[ 0., 1., 2., 3.],
[ 5., 6., 7., nan],
[ 10., 11., nan, nan],
[ 15., nan, nan, nan]])
Dimensions without coordinates: x, y
"""
if other is not None:
raise NotImplementedError("The optional argument 'other' has not "
"yet been implemented")
from .alignment import align
from .dataarray import DataArray
from .dataset import Dataset

if drop:
from .dataarray import DataArray
from .dataset import Dataset
from .alignment import align
if other is not dtypes.NA:
raise ValueError('cannot set `other` if drop=True')

if not isinstance(cond, (Dataset, DataArray)):
raise TypeError("Cond argument is %r but must be a %r or %r" %
raise TypeError("cond argument is %r but must be a %r or %r" %
(cond, Dataset, DataArray))

# align so we can use integer indexing
self, cond = align(self, cond)

Expand All @@ -627,16 +634,11 @@ def where(self, cond, other=None, drop=False):
# clip the data corresponding to coordinate dims that are not used
nonzeros = zip(clipcond.dims, np.nonzero(clipcond.values))
indexers = {k: np.unique(v) for k, v in nonzeros}
outobj = self.isel(**indexers)
outcond = cond.isel(**indexers)
else:
outobj = self
outcond = cond

# preserve attributes
out = outobj._where(outcond)
out._copy_attrs_from(self)
return out
self = self.isel(**indexers)
cond = cond.isel(**indexers)

return ops.where(self, cond, other)

def close(self):
"""Close any files linked to this object
Expand All @@ -658,42 +660,6 @@ def __exit__(self, exc_type, exc_value, traceback):
__or__ = __div__ = __eq__ = __ne__ = not_implemented


def _maybe_promote(dtype):
"""Simpler equivalent of pandas.core.common._maybe_promote"""
# N.B. these casting rules should match pandas
if np.issubdtype(dtype, float):
fill_value = np.nan
elif np.issubdtype(dtype, int):
# convert to floating point so NaN is valid
dtype = float
fill_value = np.nan
elif np.issubdtype(dtype, complex):
fill_value = np.nan + np.nan * 1j
elif np.issubdtype(dtype, np.datetime64):
fill_value = np.datetime64('NaT')
elif np.issubdtype(dtype, np.timedelta64):
fill_value = np.timedelta64('NaT')
else:
dtype = object
fill_value = np.nan
return np.dtype(dtype), fill_value


def _possibly_convert_objects(values):
"""Convert arrays of datetime.datetime and datetime.timedelta objects into
datetime64 and timedelta64, according to the pandas convention.
"""
return np.asarray(pd.Series(values.ravel())).reshape(values.shape)


def _get_fill_value(dtype):
"""Return a fill value that appropriately promotes types when used with
np.concatenate
"""
_, fill_value = _maybe_promote(dtype)
return fill_value


def full_like(other, fill_value, dtype=None):
"""Return a new object with the same shape and type as a given object.

Expand Down Expand Up @@ -761,10 +727,3 @@ def ones_like(other, dtype=None):
"""Shorthand for full_like(other, 1, dtype)
"""
return full_like(other, 1, dtype)


def is_datetime_like(dtype):
"""Check if a dtype is a subclass of the numpy datetime types
"""
return (np.issubdtype(dtype, np.datetime64) or
np.issubdtype(dtype, np.timedelta64))
20 changes: 17 additions & 3 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

_DEFAULT_FROZEN_SET = frozenset()
_DEFAULT_FILL_VALUE = object()
_DEFAULT_NAME = object()
_JOINS_WITHOUT_FILL_VALUES = frozenset({'inner', 'exact'})


class _UFuncSignature(object):
Expand Down Expand Up @@ -109,8 +111,8 @@ def result_name(objects):
# type: List[object] -> Any
# use the same naming heuristics as pandas:
# https://github.com/blaze/blaze/issues/458#issuecomment-51936356
names = {getattr(obj, 'name', None) for obj in objects}
names.discard(None)
names = {getattr(obj, 'name', _DEFAULT_NAME) for obj in objects}
names.discard(_DEFAULT_NAME)
if len(names) == 1:
name, = names
else:
Expand Down Expand Up @@ -220,11 +222,22 @@ def ordered_set_intersection(all_keys):
return [key for key in all_keys[0] if key in intersection]


def assert_and_return_exact_match(all_keys):
first_keys = all_keys[0]
for keys in all_keys[1:]:
if keys != first_keys:
raise ValueError(
'exact match required for all data variable names, but %r != %r'
% (keys, first_keys))
return first_keys


_JOINERS = {
'inner': ordered_set_intersection,
'outer': ordered_set_union,
'left': operator.itemgetter(0),
'right': operator.itemgetter(-1),
'exact': assert_and_return_exact_match,
}


Expand Down Expand Up @@ -320,7 +333,8 @@ def apply_dataset_ufunc(func, *args, **kwargs):
keep_attrs = kwargs.pop('keep_attrs', False)
first_obj = args[0] # we'll copy attrs from this in case keep_attrs=True

if dataset_join != 'inner' and fill_value is _DEFAULT_FILL_VALUE:
if (dataset_join not in _JOINS_WITHOUT_FILL_VALUES and
fill_value is _DEFAULT_FILL_VALUE):
raise TypeError('To apply an operation to datasets with different ',
'data variables, you must supply the ',
'dataset_fill_value argument.')
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from .. import conventions
from .alignment import align
from .coordinates import DatasetCoordinates, LevelCoordinatesSource, Indexes
from .common import ImplementsDatasetReduce, BaseDataObject, is_datetime_like
from .common import ImplementsDatasetReduce, BaseDataObject
from .dtypes import is_datetime_like
from .merge import (dataset_update_method, dataset_merge_method,
merge_data_and_coords)
from .utils import (Frozen, SortedKeysDict, maybe_wrap_array, hashable,
Expand Down
58 changes: 58 additions & 0 deletions xarray/core/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np


class NA(object):
"""Use as a sentinel value to indicate a dtype appropriate NA value."""


def maybe_promote(dtype):
"""Simpler equivalent of pandas.core.common._maybe_promote

Parameters
----------
dtype : np.dtype

Returns
-------
dtype : Promoted dtype that can hold missing values.
fill_value : Valid missing value for the promoted dtype.
"""
# N.B. these casting rules should match pandas
if np.issubdtype(dtype, float):
fill_value = np.nan
elif np.issubdtype(dtype, int):
# convert to floating point so NaN is valid
dtype = float
fill_value = np.nan
elif np.issubdtype(dtype, complex):
fill_value = np.nan + np.nan * 1j
elif np.issubdtype(dtype, np.datetime64):
fill_value = np.datetime64('NaT')
elif np.issubdtype(dtype, np.timedelta64):
fill_value = np.timedelta64('NaT')
else:
dtype = object
fill_value = np.nan
return np.dtype(dtype), fill_value


def get_fill_value(dtype):
"""Return an appropriate fill value for this dtype.

Parameters
----------
dtype : np.dtype

Returns
-------
fill_value : Missing value corresponding to this dtype.
"""
_, fill_value = maybe_promote(dtype)
return fill_value


def is_datetime_like(dtype):
"""Check if a dtype is a subclass of the numpy datetime types
"""
return (np.issubdtype(dtype, np.datetime64) or
np.issubdtype(dtype, np.timedelta64))
12 changes: 8 additions & 4 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pandas as pd

from . import npcompat
from . import dtypes
from .pycompat import dask_array_type
from .nputils import nanfirst, nanlast

Expand Down Expand Up @@ -148,13 +149,16 @@ def count(data, axis=None):
return sum(~isnull(data), axis=axis)


def where_method(data, cond, other=np.nan):
"""Select values from this object that are True in cond. Everything else
gets masked with other. Follows normal broadcasting and alignment rules.
"""
def where_method(data, cond, other=dtypes.NA):
if other is dtypes.NA:
other = dtypes.get_fill_value(data.dtype)
return where(cond, data, other)


def fillna(data, other):
return where(isnull(data), other, data)


@contextlib.contextmanager
def _ignore_warnings_if(condition):
if condition:
Expand Down
Loading